In [2]:
import torch
from models.vpt_official import build_promptmodel
from models.adapter import build_adapter_model
from models.bias import build_bias_model
import timm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the fine-tuned checkpoint onto the CPU. torch.load restores tensors to
# the device they were saved from by default, and the tensors shown in this
# notebook live on cuda:3, so loading without map_location fails on machines
# that do not have that device. load_state_dict copies the values into the
# model's own parameters afterwards, so the CPU copy is sufficient.
weights = torch.load('head.pt', map_location='cpu')

In [21]:
# NOTE(review): `new` is not assigned by any top-level cell visible here; the
# only assignment shown is a local variable inside get_model. This cell
# therefore relies on leftover kernel state and will raise NameError after
# Restart & Run All -- confirm where it came from or remove the cell.
new

{'head.weight': tensor([[ 0.0255, -0.0422,  0.0124,  ...,  0.0349,  0.0303,  0.0296],
         [-0.0371, -0.0056, -0.0098,  ...,  0.0175, -0.0293,  0.0156],
         [-0.0254, -0.0100,  0.0217,  ...,  0.0048, -0.0161, -0.0015],
         ...,
         [ 0.0184,  0.0269, -0.0064,  ..., -0.0155, -0.0431, -0.0590],
         [-0.0106, -0.0219, -0.0169,  ...,  0.0415,  0.0140, -0.0437],
         [-0.0216,  0.0328,  0.0309,  ...,  0.0303, -0.0291,  0.0074]],
        device='cuda:3'),
 'head.bias': tensor([-0.0032, -0.0040,  0.0005, -0.0251, -0.0265, -0.0235,  0.0156, -0.0095,
          0.0005,  0.0155,  0.0016, -0.0222,  0.0232, -0.0180, -0.0040, -0.0106,
         -0.0300,  0.0234,  0.0005,  0.0149, -0.0181,  0.0024,  0.0183,  0.0190,
          0.0015,  0.0296,  0.0254,  0.0073, -0.0030, -0.0248,  0.0051,  0.0179,
         -0.0152,  0.0259, -0.0299, -0.0143,  0.0006,  0.0181,  0.0128, -0.0130,
          0.0306, -0.0024,  0.0140, -0.0352, -0.0093,  0.0177, -0.0328,  0.0200,
          0.0104,  

In [22]:
def get_model(type, weights):
    """Build one variant of the timm ``vit_base_patch16_224_in21k`` model and load ``weights``.

    Args:
        type: Variant to build. One of ``'adapter'``, ``'prompt'``, ``'bias'``,
            ``'head'`` or ``'pretrain'``. (The name shadows the builtin
            ``type``; it is kept so existing keyword callers keep working.)
        weights: Mapping from parameter names to tensors (a state dict). For
            ``'head'`` and ``'pretrain'`` only the entries whose key contains
            ``'basic_model'`` are used, with ``'basic_model.'`` removed from
            the key so the names match the plain timm model. For the other
            variants the mapping is passed through unchanged.

    Returns:
        The constructed model after ``load_state_dict(weights, strict=False)``.
        Because ``strict=False`` is used, keys that do not match the model are
        ignored rather than reported.

    Raises:
        ValueError: If ``type`` is not one of the supported variants.
    """
    supported = ('adapter', 'prompt', 'bias', 'head', 'pretrain')
    # Validate up front: an unknown value used to skip every branch and then
    # crash with UnboundLocalError on `model` at the load_state_dict call.
    if type not in supported:
        raise ValueError(
            f"Unsupported model type {type!r}; expected one of {supported}"
        )

    vit_type = 'vit_base_patch16_224_in21k'
    class_num = 100

    if type == 'bias':
        # build_bias_model receives the architecture name, not a backbone
        # instance, so no timm model needs to be created here.
        model = build_bias_model(vit_type, num_classes=class_num)
    elif type in ('adapter', 'prompt'):
        # Both wrappers are built around a pretrained timm backbone.
        basic_model = timm.create_model(vit_type, num_classes=class_num, pretrained=True)
        if type == 'adapter':
            # `reducation_factor` (sic) is the keyword build_adapter_model is called with.
            model = build_adapter_model(basic_model=basic_model,
                                        num_classes=class_num,
                                        reducation_factor=8)
        else:
            model = build_promptmodel(basic_model=basic_model,
                                      num_classes=class_num,
                                      vpt_type='Deep',
                                      prompt_num=10,
                                      edge_size=224,
                                      patch_size=16,
                                      projection=-1,
                                      prompt_drop_rate=0.1)
    else:
        # 'head' / 'pretrain': a plain timm model, created exactly once.
        model = timm.create_model(vit_type, num_classes=class_num, pretrained=True)
        # Keep only the wrapped-backbone entries and drop the wrapper prefix
        # so the keys line up with the bare VisionTransformer.
        weights = {
            key.replace('basic_model.', ""): value
            for key, value in weights.items()
            if 'basic_model' in key
        }

    model.load_state_dict(weights, strict=False)

    return model

In [23]:
# Build the 'head' variant and load the checkpoint weights into it.
model = get_model(type='head', weights=weights)

In [24]:
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate=none)
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (ls2): 

In [12]:
# Move the model to CUDA first, then wrap it in DataParallel.
model = model.cuda()
model = torch.nn.DataParallel(model)

In [8]:
# Switch the model to evaluation mode; the call's return value is what gets
# displayed as the cell output.
# NOTE(review): this cell's execution count (8) is lower than the cells above
# it, so the displayed repr may come from an earlier `model` -- re-run in order.
model.eval()

VPT_ViT(
  (basic_model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate=none)
          (drop1): Dropout(p=0.0, inplace=False)
          (fc2): Linear(in_features=3072, out_features=768, bias=True

In [16]:
# All-ones tensor whose second dimension has size 0, so it holds no elements.
a = torch.ones(12, 0, 224, 224)

In [17]:
a[0]

tensor([], size=(0, 224, 224))

In [13]:
# Smoke test: run a forward pass on a single all-ones input of shape
# (1, 3, 224, 224); the returned tensor is displayed as the cell output.
model(torch.ones((1,3,224,224)))

tensor([[-0.7253, -0.4402,  0.1638,  0.9980,  0.7614,  0.7058, -0.6057,  0.2263,
         -0.7071, -1.3575,  0.2277, -0.8824, -0.9001, -0.0528,  0.3122,  0.8744,
         -1.0467,  0.8515, -1.7760,  0.3837,  0.0295,  1.1543, -0.7211,  0.3142,
          0.9627,  0.5895, -0.3074,  0.3187,  0.7116,  0.2369,  0.4716,  1.4654,
         -0.0281, -1.4352,  0.1938,  0.4295,  0.5904, -0.0876,  0.0358, -1.3185,
         -0.4680, -0.2265,  1.0080,  0.5477, -0.3375, -0.0193, -1.1729,  1.2254,
         -0.4107, -0.1894, -0.5551, -1.0455,  1.2511,  0.0436,  0.3652,  0.6433,
          0.3429, -0.2434,  0.1228,  0.7170,  1.6617,  0.2985, -0.9505,  0.2945,
         -1.1439,  0.0555, -0.9318,  0.4139,  0.2189, -0.4725, -0.1848,  0.1794,
         -0.0693,  0.3499, -0.8854, -0.6011,  0.1286, -0.6979, -0.3282, -1.2995,
         -0.1592, -0.0790, -0.4734, -0.0977, -0.3569,  0.7816, -0.6122, -0.9811,
          1.3498,  0.1259,  0.3068,  1.7419,  0.2985, -0.2085,  0.8939, -0.1317,
          0.5712,  0.2333,  

In [21]:
model.state_dict()

{'basic_model.head.weight': tensor([[ 0.0224, -0.0261,  0.0189,  ...,  0.0144,  0.0219,  0.0320],
         [-0.0322,  0.0128,  0.0263,  ..., -0.0030, -0.0275, -0.0160],
         [-0.0039,  0.0203,  0.0200,  ..., -0.0093,  0.0335,  0.0088],
         ...,
         [ 0.0270,  0.0254,  0.0223,  ..., -0.0214, -0.0114, -0.0344],
         [-0.0158, -0.0157,  0.0183,  ...,  0.0277,  0.0158,  0.0021],
         [-0.0343, -0.0153,  0.0068,  ...,  0.0311, -0.0055,  0.0076]]),
 'basic_model.head.bias': tensor([-0.0018, -0.0015,  0.0017, -0.0182, -0.0230, -0.0235,  0.0202, -0.0089,
          0.0004,  0.0152, -0.0003, -0.0231,  0.0215, -0.0164, -0.0029, -0.0075,
         -0.0331,  0.0270, -0.0061,  0.0196, -0.0209,  0.0055,  0.0183,  0.0070,
          0.0062,  0.0283,  0.0274,  0.0057, -0.0041, -0.0266,  0.0062,  0.0189,
         -0.0164,  0.0224, -0.0304, -0.0171,  0.0052,  0.0184,  0.0132, -0.0155,
          0.0297, -0.0020,  0.0172, -0.0314, -0.0159,  0.0206, -0.0345,  0.0227,
          0.0121,  0

[True,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 