changzy00/pytorch-attention

This codebase is a PyTorch implementation of various attention mechanisms, CNNs, Vision Transformers and MLP-Like models.

If it is helpful for your work, please give it a ⭐

Updating...

Install

git clone https://github.com/changzy00/pytorch-attention.git
cd pytorch-attention

Content

Attention Mechanisms

1. Squeeze-and-Excitation Attention

  • Squeeze-and-Excitation Networks (CVPR 2018) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.se_module import SELayer

x = torch.randn(2, 64, 32, 32)
attn = SELayer(64)
y = attn(x)
print(y.shape)
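For reference, the squeeze-and-excitation idea fits in a few lines: global average pooling squeezes each channel to one number, a bottleneck MLP ending in a sigmoid turns those numbers into per-channel gates, and the input is rescaled by the gates. This is a minimal sketch assuming a reduction ratio of 16; `SESketch` is a hypothetical name, and the repo's `SELayer` may differ in details.

```python
import torch
import torch.nn as nn


class SESketch(nn.Module):
    """Squeeze-and-excitation block (illustrative sketch, not the repo's SELayer)."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # squeeze: (B, C, H, W) -> (B, C, 1, 1)
        self.fc = nn.Sequential(             # excitation: bottleneck MLP ending in a sigmoid
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w  # scale: reweight every channel by its learned gate


x = torch.randn(2, 64, 32, 32)
y = SESketch(64)(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```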

2. Convolutional Block Attention Module

  • CBAM: Convolutional Block Attention Module (ECCV 2018) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.cbam import CBAM

x = torch.randn(2, 64, 32, 32)
attn = CBAM(64)
y = attn(x)
print(y.shape)
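CBAM applies two attention maps in sequence: a channel map computed by a shared MLP over average- and max-pooled descriptors, then a spatial map computed by a 7x7 convolution over the channel-wise mean and max. The sketch below assumes a reduction ratio of 16 and a 7x7 spatial kernel; `CBAMSketch` is a hypothetical name, not the repo's `CBAM` class.

```python
import torch
import torch.nn as nn


class CBAMSketch(nn.Module):
    """Channel attention followed by spatial attention (illustrative sketch)."""

    def __init__(self, channels: int, reduction: int = 16, kernel_size: int = 7):
        super().__init__()
        # Shared MLP (as 1x1 convs) applied to both pooled channel descriptors
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
        )
        # Two input maps (channel-wise mean and max) -> one spatial attention map
        self.spatial = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg = self.mlp(x.mean(dim=(2, 3), keepdim=True))
        mx = self.mlp(x.amax(dim=(2, 3), keepdim=True))
        x = x * torch.sigmoid(avg + mx)  # channel attention, gate shape (B, C, 1, 1)
        s = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
        return x * torch.sigmoid(self.spatial(s))  # spatial attention, gate shape (B, 1, H, W)


x = torch.randn(2, 64, 32, 32)
y = CBAMSketch(64)(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```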

3. Bottleneck Attention Module

  • BAM: Bottleneck Attention Module (BMVC 2018) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.bam import BAM

x = torch.randn(2, 64, 32, 32)
attn = BAM(64)
y = attn(x)
print(y.shape)

4. Double Attention

  • A2-Nets: Double Attention Networks (NeurIPS 2018) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.double_attention import DoubleAttention

x = torch.randn(2, 64, 32, 32)
attn = DoubleAttention(64, 32, 32)
y = attn(x)
print(y.shape)

5. Style Attention

  • SRM: A Style-based Recalibration Module for Convolutional Neural Networks (ICCV 2019) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.srm import SRM

x = torch.randn(2, 64, 32, 32)
attn = SRM(64)
y = attn(x)
print(y.shape)
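SRM recalibrates channels from their "style": it pools each channel into its mean and standard deviation, mixes that pair with a channel-wise fully connected layer (one small weight vector per channel), then applies batch norm and a sigmoid gate. A minimal sketch, with the channel-wise FC expressed as a grouped `Conv1d`; `SRMSketch` is a hypothetical name and not necessarily how the repo's `SRM` is written.

```python
import torch
import torch.nn as nn


class SRMSketch(nn.Module):
    """Style-based recalibration (illustrative sketch)."""

    def __init__(self, channels: int):
        super().__init__()
        # Channel-wise FC: with groups=channels, each channel only mixes its own (mean, std)
        self.cfc = nn.Conv1d(channels, channels, kernel_size=2, groups=channels, bias=False)
        self.bn = nn.BatchNorm1d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        flat = x.reshape(b, c, -1)
        # Style pooling: per-channel mean and standard deviation -> (B, C, 2)
        style = torch.stack([flat.mean(dim=2), flat.std(dim=2)], dim=2)
        z = self.bn(self.cfc(style))  # (B, C, 1)
        return x * torch.sigmoid(z).view(b, c, 1, 1)


x = torch.randn(2, 64, 32, 32)
y = SRMSketch(64)(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```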

6. Global Context Attention

  • GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond (ICCVW 2019) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.gc_module import GCModule

x = torch.randn(2, 64, 32, 32)
attn = GCModule(64)
y = attn(x)
print(y.shape)
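The global context block has three steps: a 1x1 conv plus softmax over all positions gives one attention map shared by every query, a weighted sum turns the features into a single context vector, and a bottleneck transform (1x1 conv, LayerNorm, ReLU, 1x1 conv) is added back to every position. The sketch assumes a bottleneck ratio of 16 and additive fusion; `GCSketch` is a hypothetical name, not the repo's `GCModule`.

```python
import torch
import torch.nn as nn


class GCSketch(nn.Module):
    """Global context block (illustrative sketch)."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        hidden = channels // reduction
        self.mask = nn.Conv2d(channels, 1, kernel_size=1)  # one attention logit per position
        self.transform = nn.Sequential(
            nn.Conv2d(channels, hidden, 1),
            nn.LayerNorm([hidden, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # Context modeling: softmax over all H*W positions, then a weighted sum of features
        attn = self.mask(x).reshape(b, 1, h * w).softmax(dim=-1)          # (B, 1, HW)
        context = torch.bmm(x.reshape(b, c, h * w), attn.transpose(1, 2))  # (B, C, 1)
        # Transform the context vector and fuse it by broadcast addition
        return x + self.transform(context.view(b, c, 1, 1))


x = torch.randn(2, 64, 32, 32)
y = GCSketch(64)(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```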

7. Selective Kernel Attention

  • Selective Kernel Networks (CVPR 2019) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.sk_module import SKLayer

x = torch.randn(2, 64, 32, 32)
attn = SKLayer(64)
y = attn(x)
print(y.shape)

8. Linear Context Attention

  • Linear Context Transform Block (AAAI 2020) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.lct import LCT

x = torch.randn(2, 64, 32, 32)
attn = LCT(64, groups=8)
y = attn(x)
print(y.shape)

9. Gated Channel Attention

  • Gated Channel Transformation for Visual Recognition (CVPR 2020) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.gate_channel_module import GCT

x = torch.randn(2, 64, 32, 32)
attn = GCT(64)
y = attn(x)
print(y.shape)
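Gated Channel Transformation replaces SE's MLP with three lightweight steps: an l2-norm embedding of each channel (scaled by `alpha`), a normalization across channels, and a `1 + tanh(...)` gate with per-channel `gamma` and `beta`. This sketch assumes the l2 variant, `eps = 1e-5`, and zero-initialized `gamma` and `beta`, which makes the module start out as the identity. `GatedChannelSketch` is a hypothetical name, not the repo's `GCT`.

```python
import torch
import torch.nn as nn


class GatedChannelSketch(nn.Module):
    """Gated channel transformation, l2 variant (illustrative sketch)."""

    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1, 1))   # embedding weight
        self.gamma = nn.Parameter(torch.zeros(1, channels, 1, 1))  # gating weight
        self.beta = nn.Parameter(torch.zeros(1, channels, 1, 1))   # gating bias
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Global context embedding: per-channel l2 norm scaled by alpha
        emb = self.alpha * (x.pow(2).sum(dim=(2, 3), keepdim=True) + self.eps).sqrt()
        # Channel normalization: divide by the root-mean-square embedding across channels
        norm = self.gamma / (emb.pow(2).mean(dim=1, keepdim=True) + self.eps).sqrt()
        # Gating adaptation: 1 + tanh(...) keeps the gate in (0, 2)
        gate = 1.0 + torch.tanh(emb * norm + self.beta)
        return x * gate


x = torch.randn(2, 64, 32, 32)
y = GatedChannelSketch(64)(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```

With `gamma` and `beta` at zero the gate is exactly 1, so inserting the module into a pretrained network does not change its output until training moves those parameters.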

10. Efficient Channel Attention

  • ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks (CVPR 2020) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.eca import ECALayer

x = torch.randn(2, 64, 32, 32)
attn = ECALayer(64)
y = attn(x)
print(y.shape)
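ECA drops SE's bottleneck MLP and instead runs a small 1D convolution across the pooled channel descriptor, with the kernel size chosen from the channel count as the nearest odd value of |log2(C)/γ + b/γ|. The sketch assumes γ = 2 and b = 1; `eca_kernel_size` and `ECASketch` are hypothetical names, not the repo's `ECALayer`.

```python
import math

import torch
import torch.nn as nn


def eca_kernel_size(channels: int, gamma: int = 2, b: int = 1) -> int:
    """Adaptive 1D kernel size: |log2(C)/gamma + b/gamma|, bumped up to the next odd value."""
    t = int(abs(math.log2(channels) / gamma + b / gamma))
    return t if t % 2 else t + 1


class ECASketch(nn.Module):
    """Efficient channel attention (illustrative sketch)."""

    def __init__(self, channels: int):
        super().__init__()
        k = eca_kernel_size(channels)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=k // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = x.mean(dim=(2, 3))          # (B, C) channel descriptor
        y = self.conv(y.unsqueeze(1))   # (B, 1, C): local cross-channel interaction
        return x * torch.sigmoid(y).view(x.size(0), -1, 1, 1)


x = torch.randn(2, 64, 32, 32)
y = ECASketch(64)(x)
print(eca_kernel_size(64), y.shape)  # 3 torch.Size([2, 64, 32, 32])
```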

11. Triplet Attention

  • Rotate to Attend: Convolutional Triplet Attention Module (WACV 2021) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.triplet_attention import TripletAttention

x = torch.randn(2, 64, 32, 32)
attn = TripletAttention(64)
y = attn(x)
print(y.shape)

12. Gaussian Context Attention

  • Gaussian Context Transformer (CVPR 2021) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.gct import GCT

x = torch.randn(2, 64, 32, 32)
attn = GCT(64)
y = attn(x)
print(y.shape)

13. Coordinate Attention

  • Coordinate Attention for Efficient Mobile Network Design (CVPR 2021) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.coordatten import CoordinateAttention

x = torch.randn(2, 64, 32, 32)
attn = CoordinateAttention(64, 64)
y = attn(x)
print(y.shape)
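Coordinate attention factorizes global pooling into two 1D poolings (along width and along height), encodes both with a shared 1x1 conv, then splits the result back into a height-wise gate and a width-wise gate. The sketch assumes a reduction of 32 with a floor of 8 hidden channels and uses `nn.Hardswish` as the nonlinearity; `CoordAttSketch` is a hypothetical name and takes a single channel count, unlike the repo's `CoordinateAttention(64, 64)`.

```python
import torch
import torch.nn as nn


class CoordAttSketch(nn.Module):
    """Coordinate attention (illustrative sketch)."""

    def __init__(self, channels: int, reduction: int = 32):
        super().__init__()
        mid = max(8, channels // reduction)
        self.encode = nn.Sequential(  # shared 1x1 conv over the concatenated H and W strips
            nn.Conv2d(channels, mid, 1), nn.BatchNorm2d(mid), nn.Hardswish())
        self.conv_h = nn.Conv2d(mid, channels, 1)
        self.conv_w = nn.Conv2d(mid, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        x_h = x.mean(dim=3, keepdim=True)                  # (B, C, H, 1): pool along width
        x_w = x.mean(dim=2, keepdim=True).transpose(2, 3)  # (B, C, W, 1): pool along height
        y = self.encode(torch.cat([x_h, x_w], dim=2))      # (B, mid, H + W, 1)
        y_h, y_w = torch.split(y, [h, w], dim=2)
        a_h = torch.sigmoid(self.conv_h(y_h))                  # (B, C, H, 1)
        a_w = torch.sigmoid(self.conv_w(y_w.transpose(2, 3)))  # (B, C, 1, W)
        return x * a_h * a_w


x = torch.randn(2, 64, 16, 24)  # non-square input to exercise the H/W split
y = CoordAttSketch(64)(x)
print(y.shape)  # torch.Size([2, 64, 16, 24])
```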

14. SimAM

  • SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks (ICML 2021) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.simam import simam_module

x = torch.randn(2, 64, 32, 32)
attn = simam_module(64)
y = attn(x)
print(y.shape)
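SimAM has no learnable parameters: each activation is weighted by the sigmoid of an inverse "energy" that grows with its squared distance from the channel mean, relative to the channel variance. The sketch below assumes the regularizer λ = 1e-4 and an unbiased variance over the H·W − 1 other positions; `simam_sketch` is a hypothetical function, not the repo's `simam_module`.

```python
import torch


def simam_sketch(x: torch.Tensor, e_lambda: float = 1e-4) -> torch.Tensor:
    """Parameter-free SimAM weighting (illustrative sketch)."""
    n = x.shape[2] * x.shape[3] - 1
    d = (x - x.mean(dim=(2, 3), keepdim=True)).pow(2)  # squared distance from channel mean
    v = d.sum(dim=(2, 3), keepdim=True) / n            # per-channel variance estimate
    e_inv = d / (4 * (v + e_lambda)) + 0.5             # inverse of the minimal energy
    return x * torch.sigmoid(e_inv)


x = torch.randn(2, 64, 32, 32)
y = simam_sketch(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```

On a constant feature map every distance is zero, so every activation is simply scaled by sigmoid(0.5).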

15. Dual Attention

  • Dual Attention Network for Scene Segmentation (CVPR 2019) pdf

  • Model Overview

  • Code
import torch
from attention_mechanisms.dual_attention import PAM, CAM

x = torch.randn(2, 64, 32, 32)
# attn = PAM(64)  # position attention module
attn = CAM()  # channel attention module
y = attn(x)
print(y.shape)
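The channel attention module (CAM) computes a C×C affinity matrix from the flattened features, turns it into attention weights with a softmax, re-aggregates the channels, and adds the result back through a learnable scale `gamma` that starts at zero. The "row max minus energy" step below mirrors the public DANet code as we understand it; treat it as an assumption. `ChannelAttentionSketch` is a hypothetical name, not the repo's `CAM`.

```python
import torch
import torch.nn as nn


class ChannelAttentionSketch(nn.Module):
    """DANet-style channel attention module (illustrative sketch)."""

    def __init__(self):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual scale, starts at 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        feat = x.reshape(b, c, h * w)                   # (B, C, N)
        energy = torch.bmm(feat, feat.transpose(1, 2))  # (B, C, C) channel affinities
        # Subtract from the row maximum before the softmax (assumed, following DANet)
        energy = energy.amax(dim=-1, keepdim=True) - energy
        attn = energy.softmax(dim=-1)
        out = torch.bmm(attn, feat).reshape(b, c, h, w)
        return self.gamma * out + x


x = torch.randn(2, 64, 32, 32)
y = ChannelAttentionSketch()(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```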

Vision Transformers

1. ViT Model

  • An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ICLR 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.ViT import VisionTransformer

x = torch.randn(2, 3, 224, 224)
model = VisionTransformer()
y = model(x)
print(y.shape) #[2, 1000]
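The front end of ViT turns a 224x224 image into a token sequence: 16x16 patches give (224 / 16)² = 196 patch tokens, a learnable class token makes 197, and a learnable position embedding is added. A strided convolution with kernel size equal to the patch size is a common way to do the per-patch linear projection. The sketch assumes ViT-Base sizes (patch 16, width 768); `PatchEmbedSketch` is a hypothetical name and the repo's `VisionTransformer` defaults may differ.

```python
import torch
import torch.nn as nn


class PatchEmbedSketch(nn.Module):
    """Patch embedding + class token + position embedding (illustrative sketch)."""

    def __init__(self, img_size: int = 224, patch_size: int = 16, in_chans: int = 3, dim: int = 768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # 14 * 14 = 196 for the defaults
        # kernel = stride = patch size: equivalent to flattening each patch and projecting it
        self.proj = nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x).flatten(2).transpose(1, 2)         # (B, 196, dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)      # (B, 1, dim)
        return torch.cat([cls, x], dim=1) + self.pos_embed  # (B, 197, dim)


x = torch.randn(2, 3, 224, 224)
tokens = PatchEmbedSketch()(x)
print(tokens.shape)  # torch.Size([2, 197, 768])
```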

2. XCiT Model

  • XCiT: Cross-Covariance Image Transformer (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.xcit import xcit_nano_12_p16
x = torch.randn(2, 3, 224, 224)
model = xcit_nano_12_p16()
y = model(x)
print(y.shape)

3. PiT Model

  • Rethinking Spatial Dimensions of Vision Transformers (ICCV 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.pit import pit_ti
x = torch.randn(2, 3, 224, 224)
model = pit_ti()
y = model(x)
print(y.shape)

4. CvT Model

  • CvT: Introducing Convolutions to Vision Transformers (ICCV 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.cvt import cvt_13
x = torch.randn(2, 3, 224, 224)
model = cvt_13()
y = model(x)
print(y.shape)

5. PvT Model

  • Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions (ICCV 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.pvt import pvt_t
x = torch.randn(2, 3, 224, 224)
model = pvt_t()
y = model(x)
print(y.shape)

6. CMT Model

  • CMT: Convolutional Neural Networks Meet Vision Transformers (CVPR 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.cmt import cmt_ti
x = torch.randn(2, 3, 224, 224)
model = cmt_ti()
y = model(x)
print(y.shape)

7. PoolFormer Model

  • MetaFormer is Actually What You Need for Vision (CVPR 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.poolformer import poolformer_12
x = torch.randn(2, 3, 224, 224)
model = poolformer_12()
y = model(x)
print(y.shape)
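PoolFormer's point is that the token mixer inside a MetaFormer block can be as simple as average pooling. A minimal sketch of that mixer, assuming a 3x3 window with `count_include_pad=False` so border positions are not biased by zero padding; the input is subtracted because the surrounding block already adds a residual. `PoolingMixerSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn


class PoolingMixerSketch(nn.Module):
    """Average-pooling token mixer (illustrative sketch)."""

    def __init__(self, pool_size: int = 3):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2,
                                 count_include_pad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Subtract the input: the enclosing block adds x back via its residual connection
        return self.pool(x) - x


x = torch.randn(2, 64, 56, 56)
y = PoolingMixerSketch()(x)
print(y.shape)  # torch.Size([2, 64, 56, 56])
```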

8. KVT Model

  • KVT: k-NN Attention for Boosting Vision Transformers (ECCV 2022) pdf

  • Code
import torch
from vision_transformers.kvt import KVT
x = torch.randn(2, 3, 224, 224)
model = KVT()
y = model(x)
print(y.shape)

9. MobileViT Model

  • MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer (ICLR 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.mobilevit import mobilevit_s
x = torch.randn(2, 3, 224, 224)
model = mobilevit_s()
y = model(x)
print(y.shape)

10. P2T Model

  • Pyramid Pooling Transformer for Scene Understanding (TPAMI 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.p2t import p2t_tiny
x = torch.randn(2, 3, 224, 224)
model = p2t_tiny()
y = model(x)
print(y.shape)

11. EfficientFormer Model

  • EfficientFormer: Vision Transformers at MobileNet Speed (NeurIPS 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.efficientformer import efficientformer_l1
x = torch.randn(2, 3, 224, 224)
model = efficientformer_l1()
y = model(x)
print(y.shape)

12. ShiftViT Model

  • When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism (AAAI 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.shiftvit import shift_t
x = torch.randn(2, 3, 224, 224)
model = shift_t()
y = model(x)
print(y.shape)

13. CSWin Model

  • CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows (CVPR 2022) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.cswin import CSWin_64_12211_tiny_224
x = torch.randn(2, 3, 224, 224)
model = CSWin_64_12211_tiny_224()
y = model(x)
print(y.shape)

14. DilateFormer Model

  • DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition (TMM 2023) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.dilateformer import dilateformer_tiny
x = torch.randn(2, 3, 224, 224)
model = dilateformer_tiny()
y = model(x)
print(y.shape)

15. BViT Model

  • BViT: Broad Attention based Vision Transformer (TNNLS 2023) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.bvit import BViT_S
x = torch.randn(2, 3, 224, 224)
model = BViT_S()
y = model(x)
print(y.shape)

16. MOAT Model

  • MOAT: Alternating Mobile Convolution and Attention Brings Strong Vision Models (ICLR 2023) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.moat import moat_0
x = torch.randn(2, 3, 224, 224)
model = moat_0()
y = model(x)
print(y.shape)

17. SegFormer Model

  • SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.segformer import SegFormer
x = torch.randn(2, 3, 512, 512)
model = SegFormer(num_classes=50)
y = model(x)
print(y.shape)

18. SETR Model

  • Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers (CVPR 2021) pdf

  • Model Overview

  • Code
import torch
from vision_transformers.setr import SETR
x = torch.randn(2, 3, 480, 480)
model = SETR(num_classes=50)
y = model(x)
print(y.shape)

Convolutional Neural Networks (CNNs)

1. NiN Model

  • Network In Network (ICLR 2014) pdf

  • Model Overview

  • Code
import torch
from cnns.NiN import NiN 
x = torch.randn(2, 3, 224, 224)
model = NiN()
y = model(x)
print(y.shape)

2. ResNet Model

  • Deep Residual Learning for Image Recognition (CVPR 2016) pdf

  • Model Overview

  • Code
import torch
from cnns.resnet import resnet18 
x = torch.randn(2, 3, 224, 224)
model = resnet18()
y = model(x)
print(y.shape)

3. WideResNet Model

  • Wide Residual Networks (BMVC 2016) pdf

  • Model Overview

  • Code
import torch
from cnns.wideresnet import wideresnet
x = torch.randn(2, 3, 224, 224)
model = wideresnet()
y = model(x)
print(y.shape)

4. DenseNet Model

  • Densely Connected Convolutional Networks (CVPR 2017) pdf

  • Model Overview

  • Code
import torch
from cnns.densenet import densenet121
x = torch.randn(2, 3, 224, 224)
model = densenet121()
y = model(x)
print(y.shape)

5. PyramidNet Model

  • Deep Pyramidal Residual Networks (CVPR 2017) pdf

  • Model Overview

  • Code
import torch
from cnns.pyramidnet import pyramidnet18
x = torch.randn(2, 3, 224, 224)
model = pyramidnet18()
y = model(x)
print(y.shape)

6. MobileNetV1 Model

  • MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications (CVPR 2017) pdf

  • Model Overview

  • Code
import torch
from cnns.mobilenetv1 import MobileNetv1
x = torch.randn(2, 3, 224, 224)
model = MobileNetv1()
y = model(x)
print(y.shape)
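MobileNetV1 is built from depthwise-separable convolutions: a 3x3 depthwise conv filters each channel on its own (`groups=in_ch`), and a 1x1 pointwise conv mixes channels. For 32 -> 64 channels that needs 3·3·32 + 32·64 = 2336 conv weights instead of 3·3·32·64 = 18432 for a standard 3x3 conv. `depthwise_separable` is a hypothetical helper, not the repo's building block.

```python
import torch
import torch.nn as nn


def depthwise_separable(in_ch: int, out_ch: int, stride: int = 1) -> nn.Sequential:
    """Depthwise 3x3 + pointwise 1x1, each followed by BN and ReLU (illustrative sketch)."""
    return nn.Sequential(
        nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False),  # depthwise
        nn.BatchNorm2d(in_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_ch, out_ch, 1, bias=False),  # pointwise: mixes channels
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


block = depthwise_separable(32, 64, stride=2)
conv_params = sum(m.weight.numel() for m in block if isinstance(m, nn.Conv2d))
standard = nn.Conv2d(32, 64, 3, padding=1, bias=False)
print(standard.weight.numel(), conv_params)  # 18432 2336
print(block(torch.randn(2, 32, 32, 32)).shape)  # torch.Size([2, 64, 16, 16])
```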

7. MobileNetV2 Model

  • MobileNetV2: Inverted Residuals and Linear Bottlenecks (CVPR 2018) pdf

  • Model Overview

  • Code
import torch
from cnns.mobilenetv2 import MobileNetv2
x = torch.randn(2, 3, 224, 224)
model = MobileNetv2()
y = model(x)
print(y.shape)

8. MobileNetV3 Model

  • Searching for MobileNetV3 (ICCV 2019) pdf

  • Model Overview

  • Code
import torch
from cnns.mobilenetv3 import mobilenetv3_small
x = torch.randn(2, 3, 224, 224)
model = mobilenetv3_small()
y = model(x)
print(y.shape)

9. MnasNet Model

  • MnasNet: Platform-Aware Neural Architecture Search for Mobile (CVPR 2019) pdf

  • Model Overview

  • Code
import torch
from cnns.mnasnet import MnasNet
x = torch.randn(2, 3, 224, 224)
model = MnasNet()
y = model(x)
print(y.shape)

10. EfficientNetV1 Model

  • EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks (ICML 2019) pdf

  • Model Overview

  • Code
import torch
from cnns.efficientnet import EfficientNet
x = torch.randn(2, 3, 224, 224)
model = EfficientNet()
y = model(x)
print(y.shape)

11. Res2Net Model

  • Res2Net: A New Multi-scale Backbone Architecture (TPAMI 2019) pdf

  • Model Overview

  • Code
import torch
from cnns.res2net import res2net50
x = torch.randn(2, 3, 224, 224)
model = res2net50()
y = model(x)
print(y.shape)

12. MobileNeXt Model

  • Rethinking Bottleneck Structure for Efficient Mobile Network Design (ECCV 2020) pdf

  • Model Overview

  • Code
import torch
from cnns.mobilenext import MobileNeXt
x = torch.randn(2, 3, 224, 224)
model = MobileNeXt()
y = model(x)
print(y.shape)

13. GhostNet Model

  • GhostNet: More Features from Cheap Operations (CVPR 2020) pdf

  • Model Overview

  • Code
import torch
from cnns.ghostnet import ghostnet
x = torch.randn(2, 3, 224, 224)
model = ghostnet()
y = model(x)
print(y.shape)
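A Ghost module produces only part of its output channels with an ordinary convolution (the "intrinsic" maps) and generates the rest with cheap depthwise convolutions applied to those maps, then concatenates and trims to the requested width. The sketch assumes a ratio of 2, a 1x1 primary conv, and a 3x3 cheap operation; `GhostModuleSketch` is a hypothetical name, not the repo's code.

```python
import math

import torch
import torch.nn as nn


class GhostModuleSketch(nn.Module):
    """Ghost module (illustrative sketch)."""

    def __init__(self, in_ch: int, out_ch: int, ratio: int = 2, dw_size: int = 3):
        super().__init__()
        self.out_ch = out_ch
        init_ch = math.ceil(out_ch / ratio)  # intrinsic maps from a regular conv
        new_ch = init_ch * (ratio - 1)       # ghost maps from cheap depthwise ops
        self.primary = nn.Sequential(
            nn.Conv2d(in_ch, init_ch, 1, bias=False), nn.BatchNorm2d(init_ch), nn.ReLU(inplace=True))
        self.cheap = nn.Sequential(
            nn.Conv2d(init_ch, new_ch, dw_size, padding=dw_size // 2, groups=init_ch, bias=False),
            nn.BatchNorm2d(new_ch), nn.ReLU(inplace=True))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        intrinsic = self.primary(x)
        ghost = self.cheap(intrinsic)
        return torch.cat([intrinsic, ghost], dim=1)[:, : self.out_ch]  # trim to out_ch


x = torch.randn(2, 16, 8, 8)
y = GhostModuleSketch(16, 33)(x)  # 17 intrinsic + 17 ghost maps, trimmed to 33
print(y.shape)  # torch.Size([2, 33, 8, 8])
```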

14. EfficientNetV2 Model

  • EfficientNetV2: Smaller Models and Faster Training (ICML 2021) pdf

  • Model Overview

  • Code
import torch
from cnns.efficientnet import EfficientNetV2
x = torch.randn(2, 3, 224, 224)
model = EfficientNetV2()
y = model(x)
print(y.shape)

15. ConvNeXt Model

  • A ConvNet for the 2020s (CVPR 2022) pdf

  • Model Overview

  • Code
import torch
from cnns.convnext import convnext_18
x = torch.randn(2, 3, 224, 224)
model = convnext_18()
y = model(x)
print(y.shape)

16. Unet Model

  • U-Net: Convolutional Networks for Biomedical Image Segmentation (MICCAI 2015) pdf

  • Model Overview

  • Code
import torch
from cnns.unet import Unet
x = torch.randn(2, 3, 512, 512)
model = Unet(10)
y = model(x)
print(y.shape)

17. ESPNet Model

  • ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation (ECCV 2018) pdf

  • Model Overview

  • Code
import torch
from cnns.espnet import ESPNet
x = torch.randn(2, 3, 512, 512)
model = ESPNet(10)
y = model(x)
print(y.shape)

MLP-Like Models

1. MLP-Mixer Model

  • MLP-Mixer: An all-MLP Architecture for Vision (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from mlps.mlp_mixer import MLP_Mixer
x = torch.randn(2, 3, 224, 224)
model = MLP_Mixer()
y = model(x)
print(y.shape)
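Each Mixer layer alternates two MLPs: a token-mixing MLP that acts across patches (applied to the transposed token matrix, shared over channels) and a channel-mixing MLP that acts across channels (shared over patches), each with pre-LayerNorm and a skip connection. The sizes below loosely follow Mixer-S/16 (196 patches, width 512, hidden sizes 256 and 2048) and are assumptions; `MixerBlockSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn


class MixerBlockSketch(nn.Module):
    """One MLP-Mixer layer (illustrative sketch)."""

    def __init__(self, num_patches: int, dim: int, token_hidden: int = 256, channel_hidden: int = 2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.token_mlp = nn.Sequential(
            nn.Linear(num_patches, token_hidden), nn.GELU(), nn.Linear(token_hidden, num_patches))
        self.norm2 = nn.LayerNorm(dim)
        self.channel_mlp = nn.Sequential(
            nn.Linear(dim, channel_hidden), nn.GELU(), nn.Linear(channel_hidden, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, num_patches, dim)
        # Token mixing: transpose so the Linear layers act along the patch axis
        x = x + self.token_mlp(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        # Channel mixing: Linear layers act along the channel axis, per patch
        return x + self.channel_mlp(self.norm2(x))


tokens = torch.randn(2, 196, 512)
out = MixerBlockSketch(196, 512)(tokens)
print(out.shape)  # torch.Size([2, 196, 512])
```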

2. gMLP Model

  • Pay Attention to MLPs (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from mlps.gmlp import gMLP
x = torch.randn(2, 3, 224, 224)
model = gMLP()
y = model(x)
print(y.shape)

3. GFNet Model

  • Global Filter Networks for Image Classification (NeurIPS 2021) pdf

  • Model Overview

  • Code
import torch
from mlps.gfnet import GFNet
x = torch.randn(2, 3, 224, 224)
model = GFNet()
y = model(x)
print(y.shape)
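GFNet replaces self-attention with a global filter: take the 2D FFT of the token grid, multiply elementwise by a learnable complex filter, and transform back. Because the input is real, `torch.fft.rfft2` keeps only `W // 2 + 1` frequency columns, so the filter has that shape. The sketch assumes a 14x14 token grid in channels-last layout and 0.02-scaled random initialization; `GlobalFilterSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn


class GlobalFilterSketch(nn.Module):
    """Global filter layer (illustrative sketch); input is (B, H, W, C)."""

    def __init__(self, dim: int, h: int = 14, w: int = 14):
        super().__init__()
        # Learnable complex filter stored as (real, imag) pairs in the last dimension
        self.weight = nn.Parameter(torch.randn(h, w // 2 + 1, dim, 2) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        freq = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")      # (B, H, W//2+1, C), complex
        freq = freq * torch.view_as_complex(self.weight)         # elementwise global filtering
        return torch.fft.irfft2(freq, s=x.shape[1:3], dim=(1, 2), norm="ortho")


x = torch.randn(2, 14, 14, 64)
y = GlobalFilterSketch(64)(x)
print(y.shape)  # torch.Size([2, 14, 14, 64])
```

Setting the filter to all ones (real part 1, imaginary part 0) makes the layer an identity, which is a quick way to sanity-check the FFT round trip.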

4. sMLP Model

  • Sparse MLP for Image Recognition: Is Self-Attention Really Necessary? (AAAI 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.smlp import sMLPNet
x = torch.randn(2, 3, 224, 224)
model = sMLPNet()
y = model(x)
print(y.shape)

5. DynaMixer Model

  • DynaMixer: A Vision MLP Architecture with Dynamic Mixing (ICML 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.dynamixer import DynaMixer
x = torch.randn(2, 3, 224, 224)
model = DynaMixer()
y = model(x)
print(y.shape)

6. ConvMixer Model

  • Patches Are All You Need? (TMLR 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.convmixer import ConvMixer
x = torch.randn(2, 3, 224, 224)
model = ConvMixer(128, 6)
y = model(x)
print(y.shape)
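ConvMixer is compact enough to write out in full: a patch-embedding conv, then `depth` blocks of a residual depthwise conv followed by a pointwise conv (each with GELU and BatchNorm), then global pooling and a linear classifier. The sketch follows the short listing in the ConvMixer paper, assuming kernel size 9 and patch size 7; reading the repo's `ConvMixer(128, 6)` as (dim, depth) is our inference. `conv_mixer_sketch` and `Residual` are hypothetical names.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Adds the wrapped module's output to its input."""

    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x) + x


def conv_mixer_sketch(dim: int, depth: int, kernel_size: int = 9, patch_size: int = 7,
                      n_classes: int = 1000) -> nn.Sequential:
    """ConvMixer (illustrative sketch)."""
    return nn.Sequential(
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),  # patch embedding
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[nn.Sequential(
            Residual(nn.Sequential(  # spatial mixing: depthwise conv with a skip connection
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding=kernel_size // 2),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )),
            nn.Conv2d(dim, dim, kernel_size=1),  # channel mixing: pointwise conv
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ) for _ in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes),
    )


x = torch.randn(2, 3, 224, 224)
y = conv_mixer_sketch(128, 6)(x)
print(y.shape)  # torch.Size([2, 1000])
```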

7. ViP Model

  • Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition (TPAMI 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.vip import vip_s7
x = torch.randn(2, 3, 224, 224)
model = vip_s7()
y = model(x)
print(y.shape)

8. CycleMLP Model

  • CycleMLP: A MLP-like Architecture for Dense Prediction (ICLR 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.cyclemlp import CycleMLP_B1
x = torch.randn(2, 3, 224, 224)
model = CycleMLP_B1()
y = model(x)
print(y.shape)

9. Sequencer Model

  • Sequencer: Deep LSTM for Image Classification (NeurIPS 2022) pdf

  • Model Overview

  • Code
import torch
from mlps.sequencer import sequencer_s
x = torch.randn(2, 3, 224, 224)
model = sequencer_s()
y = model(x)
print(y.shape)

10. MobileViG Model

  • MobileViG: Graph-Based Sparse Attention for Mobile Vision Applications (CVPRW 2023) pdf

  • Model Overview

  • Code
import torch
from mlps.mobilevig import mobilevig_s
x = torch.randn(2, 3, 224, 224)
model = mobilevig_s()
y = model(x)
print(y.shape)

About

🦖Pytorch implementation of popular Attention Mechanisms, Vision Transformers, MLP-Like models and CNNs.🔥🔥🔥
