In [2]:
import copy

import timm
import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer
from xformers.components.attention import ScaledDotProduct
from xformers.helpers.timm_sparse_attention import TimmSparseAttention

# Input image side length, and the side length of each ViT patch
# (both are fed to VisionTransformer when building the reference model).
img_size, patch_size = 224, 16

In [3]:
# Get a reference ViT model
model = VisionTransformer(img_size=img_size, patch_size=patch_size,
                            embed_dim=96, depth=8, num_heads=8, mlp_ratio=3.,
                            qkv_bias=False, norm_layer=nn.LayerNorm).cuda()


In [7]:
# The attention mask handed to every patched attention layer.
# This snippet assumes you already have a precise mask in mind; several helpers and
# examples are provided in `xformers.components.attention.attention_patterns`.
# The single-element all-ones tensor below is only a stand-in for that mask.
my_fancy_mask = torch.ones((1, 1, 1, 1), device="cuda")

In [9]:
# Define a recursive monkey patching function
def replace_attn_with_xformers_one(module: torch.nn.Module, att_mask: torch.Tensor) -> torch.nn.Module:
    """Recursively swap every timm ``Attention`` layer for an xformers ``TimmSparseAttention``.

    The children of each replaced attention layer (``qkv``, ``proj``, dropouts, ...) are
    re-attached to its replacement, so the existing weights carry over.

    Args:
        module: root module to patch, e.g. a timm ``VisionTransformer``.
        att_mask: attention mask passed to every ``TimmSparseAttention`` that is created.

    Returns:
        A new ``TimmSparseAttention`` if ``module`` itself is a timm ``Attention``.
        Otherwise ``module`` itself, with its children patched IN PLACE.
        Deep-copy the input first if the unpatched model must be kept.
    """
    module_output = module
    if isinstance(module, timm.models.vision_transformer.Attention):
        # nn.Linear stores weight as (out_features, in_features), so the model width is the
        # qkv layer's input size. It must not be multiplied by the number of heads.
        dim = module.qkv.in_features
        # Extra parameters can be exposed in TimmSparseAttention, this is a minimal example
        module_output = TimmSparseAttention(dim, module.num_heads, attn_mask=att_mask)

    # Recurse, then (re-)attach each patched child to the output module. add_module replaces
    # any same-named submodule, so on a replaced attention layer the original qkv/proj/...
    # take the place of the freshly created ones and their weights are preserved.
    for name, child in module.named_children():
        module_output.add_module(name, replace_attn_with_xformers_one(child, att_mask))

    return module_output


In [10]:
# Now we can patch a copy of our reference model and get a sparse-aware variation.
# replace_attn_with_xformers_one mutates non-attention modules in place, so patching
# `model` directly would silently turn the reference itself into the sparse variant.
# Deep-copying keeps `model` intact and gives both variants identical starting weights.
sparse_model = replace_attn_with_xformers_one(copy.deepcopy(model), my_fancy_mask)

In [11]:
# Display the patched model: its attention layers should now print as TimmSparseAttention
sparse_model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (attn): TimmSparseAttention(
        (qkv): Linear(in_features=96, out_features=288, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=96, out_features=96, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=96, out_features=288, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=288, out_features=96, bias=True)
        (drop2): Dropout(p=0.0, inplace=F