In [33]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from vit_pytorch.efficient import ViT as ViT_efficient
from vit_pytorch import ViT

from performer_pytorch import Performer
from linformer.linformer import Linformer, LinformerSelfAttention, FeedForward_linformer, PreNorm
from linformer.reversible import ReversibleSequence, SequentialSequence
from nystrom_attention import Nystromformer, NystromAttention, FeedForward_nystromformer
from reformer_pytorch import Reformer
from transformers import LongformerModel, LongformerConfig

from torch import Tensor
import time
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Subset

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

ImportError: cannot import name 'FeedForward_linformer' from 'linformer.linformer' (/home/mikesmac/anaconda3/envs/vit/lib/python3.8/site-packages/linformer/linformer.py)

In [20]:
def count_parameters(model):
    """Return the number of trainable scalars in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted, so
    frozen weights do not contribute to the total.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

In [2]:
# Hyper-parameters shared by every model built below.
(
    num_classes,  # number of output classes of the classification head
    image_size,   # input image side length in pixels
    patch_size,   # patch side length in pixels
    d_channel,    # token embedding width
) = (101, 256, 8, 512)

# Transformer

In [22]:
# Baseline: vanilla ViT with full softmax attention (see the summary printed below).
model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=d_channel,
    mlp_dim=d_channel,   # FFN hidden width equals the embedding width
    depth=8,
    heads=8,
    dropout=0.,
    emb_dropout=0.,
)
print(f"Total number of parameters in the model: {count_parameters(model)}")
model

Total number of parameters in the model: 13289957


ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=8, p2=8)
    (1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=192, out_features=512, bias=True)
    (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-7): 8 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (dropout): Dropout(p=0.0, inplace=False)
          (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (1): FeedForward(
          (net): Sequential(
            (0):

# Performer

In [23]:
# Performer encoder with generalized attention and a ReLU kernel, wrapped in the
# efficient-attention ViT.
#
# NOTE: ``nb_features=0`` is NOT supported by stock performer_pytorch. With it, no
# projection matrix is used in the generalized kernel function, i.e. the feature
# projection is deterministic. This requires a local patch to the installed package
# (e.g. ~/anaconda3/envs/vit/lib/python3.8/site-packages/performer_pytorch):
# in performer_pytorch.py, add ``if nb_full_blocks == 0: return None`` after line 143.
transformer = Performer(
    dim=d_channel,
    depth=8,
    heads=8,
    dim_head=64,
    ff_mult=1,
    causal=False,
    generalized_attention=True,
    kernel_fn=nn.ReLU(),
    nb_features=0,               # requires the package patch described above
    feature_redraw_interval=0,
)

model = ViT_efficient(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=d_channel,
    transformer=transformer,
)
print(f"Total number of parameters in the model: {count_parameters(model)}")
model

Total number of parameters in the model: 13302245


ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=8, p2=8)
    (1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=192, out_features=512, bias=True)
    (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer): Performer(
    (net): SequentialSequence(
      (layers): ModuleList(
        (0-7): 8 x ModuleList(
          (0): PreLayerNorm(
            (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (fn): SelfAttention(
              (fast_attention): FastAttention(
                (kernel_fn): ReLU()
              )
              (to_q): Linear(in_features=512, out_features=512, bias=True)
              (to_k): Linear(in_features=512, out_features=512, bias=True)
              (to_v): Linear(in_features=512, out_features=512, bias=True)
              (to_out): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.0,

# Linformer

In [34]:
# The models.linformer version rebinds the `Linformer` name from the first import cell.
from models.linformer import Linformer

# Linformer's low-rank projections are built for a fixed sequence length, so compute
# the token count exactly with integer arithmetic (the old float-division + int()
# truncation would silently mis-size it for non-divisible sizes).
# +1 presumably accounts for the class token ViT_efficient prepends — TODO confirm.
assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
num_tokens = (image_size // patch_size) ** 2 + 1

# NOTE(review): the summary printed below shows FeedForward w1 as Linear(512 -> 1024),
# i.e. a wider FFN than the ViT (mlp_dim = d_channel) and Performer (ff_mult = 1)
# cells, so this parameter count is not directly comparable to theirs. Check whether
# models.linformer.Linformer exposes an FFN-width option.
transformer = Linformer(
    dim = d_channel,
    seq_len = num_tokens,
    depth = 8,
    heads = 8,
    k = 64,
    dim_head = 64,
    one_kv_head = True,
    share_kv = True
    )

model = ViT_efficient(
    dim = d_channel,
    image_size = image_size,
    patch_size = patch_size,
    num_classes = num_classes,
    transformer = transformer
)
print(f"Total number of parameters in the model: {count_parameters(model)}")
model

Total number of parameters in the model: 14080997


ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=8, p2=8)
    (1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=192, out_features=512, bias=True)
    (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer): Linformer(
    (net): SequentialSequence(
      (layers): ModuleList(
        (0-7): 8 x ModuleList(
          (0): PreNorm(
            (fn): LinformerSelfAttention(
              (to_q): Linear(in_features=512, out_features=512, bias=False)
              (to_k): Linear(in_features=512, out_features=64, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
              (to_out): Linear(in_features=512, out_features=512, bias=True)
            )
            (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          )
          (1): PreNorm(
            (fn): FeedForward(
              (w1): Linear(in_features=512, out_features=1024, bias=True

# Nystromformer

In [38]:
# The models.nystromformer version rebinds the `Nystromformer` name imported in the first cell.
from models.nystromformer import Nystromformer

# Nystromformer encoder wrapped in the efficient-attention ViT.
transformer = Nystromformer(
    dim=d_channel,
    depth=8,
    heads=8,
    dim_head=64,
    num_landmarks=16,
    pinv_iterations=6,
)

model = ViT_efficient(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=d_channel,
    transformer=transformer,
)
print(f"Total number of parameters in the model: {count_parameters(model)}")
model

Total number of parameters in the model: 13292069


ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=8, p2=8)
    (1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=192, out_features=512, bias=True)
    (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer): Nystromformer(
    (layers): ModuleList(
      (0-7): 8 x ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): NystromAttention(
            (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
            (res_conv): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
     

# Longformer

In [47]:
from transformers import LongformerModel, LongformerConfig
from models.longformer import ViT_for_longformer

In [53]:
# Hugging Face Longformer encoder, wrapped by the ViT_for_longformer adapter.
#
# NOTE(review): hidden_size is half of d_channel (256), while the patch embedding in
# the summary below outputs 512 features. ViT_for_longformer presumably projects
# 512 -> 256 — confirm.
# NOTE(review): the summary also shows a word_embeddings table of Embedding(30522, 256)
# (~7.8M parameters) included in the count. If the adapter feeds patch embeddings
# directly (e.g. via inputs_embeds), that table is likely unused, which inflates the
# comparison — verify.
config = LongformerConfig(
    hidden_size=int(d_channel / 2),
    intermediate_size=d_channel,
    num_hidden_layers=8,
    num_attention_heads=8,
    attention_window=128,
    max_position_embeddings=4097,
)
transformer = LongformerModel(config)

model = ViT_for_longformer(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=d_channel,
    transformer=transformer,
)

print(f"Total number of parameters in the model: {count_parameters(model)}")
model

Total number of parameters in the model: 15403493


ViT_for_longformer(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=8, p2=8)
    (1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=192, out_features=512, bias=True)
    (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer): LongformerModel(
    (embeddings): LongformerEmbeddings(
      (word_embeddings): Embedding(30522, 256, padding_idx=1)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(4097, 256, padding_idx=1)
    )
    (encoder): LongformerEncoder(
      (layer): ModuleList(
        (0-7): 8 x LongformerLayer(
          (attention): LongformerAttention(
            (self): LongformerSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=2

# Reformer

In [58]:
from models.reformer import ViT_for_reformer

# Reformer encoder: LSH self-attention inside reversible blocks (see the summary below).
transformer = Reformer(
    dim=d_channel,
    depth=8,
    heads=8,
    causal=False,
    ff_mult=1,
    bucket_size=32,
    n_hashes=1,
    lsh_dropout=0,
    reverse_thres=4096,
)

model = ViT_for_reformer(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=d_channel,
    transformer=transformer,
)

print(f"Total number of parameters in the model: {count_parameters(model)}")
model

Total number of parameters in the model: 11191781


ViT_for_reformer(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=8, p2=8)
    (1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=192, out_features=512, bias=True)
    (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer): Reformer(
    (layers): ReversibleSequence(
      (blocks): ModuleList(
        (0-7): 8 x ReversibleBlock(
          (f): Deterministic(
            (net): PreNorm(
              (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (fn): LSHSelfAttention(
                (toqk): Linear(in_features=512, out_features=512, bias=False)
                (tov): Linear(in_features=512, out_features=512, bias=False)
                (to_out): Linear(in_features=512, out_features=512, bias=True)
                (lsh_attn): LSHAttention(
                  (dropout): Dropout(p=0, inplace=False)
                  (dropout_for_hash): Dropout(p=0.0