In [1]:
from vim.models_mamba import VisionMamba
from pathlib import PurePath
import torch
import sys
import os
# Get the absolute path of the vim directory
vim_path = os.path.abspath('vim')
# Add the vim directory to the system path
sys.path.append(vim_path)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from huggingface_hub import snapshot_download

VIM_REPO = "hustvl/Vim-small-midclstok"

pretrained_model_dir = snapshot_download(
    repo_id=VIM_REPO,
    # Comment the next line the first time to have the files be
    # downloaded.
    local_files_only=True
)

MODEL_FILE = PurePath(pretrained_model_dir, "vim_s_midclstok_ft_81p6acc.pth")
print(MODEL_FILE)

/home/eh_abdol/.cache/huggingface/hub/models--hustvl--Vim-small-midclstok/snapshots/babc4440f5fab6e08d97e371afa639c8cf98bf2c/vim_s_midclstok_ft_81p6acc.pth


In [3]:
model = VisionMamba(
    patch_size=16,
    stride=8,
    embed_dim=384,
    depth=24,
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    final_pool_type='mean',
    if_abs_pos_embed=True,
    if_rope=False,
    if_rope_residual=False,
    bimamba_type="v2",
    if_cls_token=True,
    if_devide_out=True,
    use_middle_cls_token=True,
    num_classes=1000,
    drop_rate=0.0,
    drop_path_rate=0.1,
    drop_block_rate=None,
    img_size=224,
)

In [4]:
print(model)

VisionMamba(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(8, 8))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (head): Linear(in_features=384, out_features=1000, bias=True)
  (drop_path): DropPath()
  (layers): ModuleList(
    (0-1): 2 x Block(
      (mixer): Mamba(
        (in_proj): Linear(in_features=384, out_features=1536, bias=False)
        (conv1d): Conv1d(768, 768, kernel_size=(4,), stride=(1,), padding=(3,), groups=768)
        (act): SiLU()
        (x_proj): Linear(in_features=768, out_features=56, bias=False)
        (dt_proj): Linear(in_features=24, out_features=768, bias=True)
        (conv1d_b): Conv1d(768, 768, kernel_size=(4,), stride=(1,), padding=(3,), groups=768)
        (x_proj_b): Linear(in_features=768, out_features=56, bias=False)
        (dt_proj_b): Linear(in_features=24, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=384, bias=False)
      )
      (norm):

In [6]:
checkpoint = torch.load(str(MODEL_FILE), map_location='cpu')
# Important: make sure the values of this match what's used to instantiate the VisionMamba class.
# If not, loading the checkpoint will fail.
checkpoint["args"]

model.load_state_dict(checkpoint["model"])

<All keys matched successfully>

In [7]:
model.eval()
model.to("cuda")

VisionMamba(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(8, 8))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (head): Linear(in_features=384, out_features=1000, bias=True)
  (drop_path): DropPath()
  (layers): ModuleList(
    (0-1): 2 x Block(
      (mixer): Mamba(
        (in_proj): Linear(in_features=384, out_features=1536, bias=False)
        (conv1d): Conv1d(768, 768, kernel_size=(4,), stride=(1,), padding=(3,), groups=768)
        (act): SiLU()
        (x_proj): Linear(in_features=768, out_features=56, bias=False)
        (dt_proj): Linear(in_features=24, out_features=768, bias=True)
        (conv1d_b): Conv1d(768, 768, kernel_size=(4,), stride=(1,), padding=(3,), groups=768)
        (x_proj_b): Linear(in_features=768, out_features=56, bias=False)
        (dt_proj_b): Linear(in_features=24, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=384, bias=False)
      )
      (norm):

In [8]:
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

test_image = Image.open("test.jpg")
test_image = test_image.resize((224, 224))
image_as_tensor = transforms.ToTensor()(test_image)
normalized_tensor = transforms.Normalize(
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(image_as_tensor)

In [9]:
# test_image
x = normalized_tensor.unsqueeze(0).cuda()
pred = model(x)
# Note: the returned label can be verified with https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/
pred.argmax()

tensor(0, device='cuda:0')