In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
from PIL import Image
import torch
from torchvision import transforms

In [4]:
img = Image.open('../examples/simple/img.jpg')
img = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])(img).unsqueeze(0)
img.shape

torch.Size([1, 3, 384, 384])

In [5]:
from importlib import reload
import pytorch_pretrained_vit
reload(pytorch_pretrained_vit)

<module 'pytorch_pretrained_vit' from '/home/luke/projects/experiments/ViT-PyTorch/pytorch_pretrained_vit/__init__.py'>

In [6]:
# model = pytorch_pretrained_vit.ViT(name='B_16', pretrained=False, num_classes=21843)
model = pytorch_pretrained_vit.ViT(name='B_16_imagenet1k', pretrained=False, num_classes=1000)

In [7]:
# list(model.state_dict().keys())

### Jax

In [8]:
import numpy as np

In [9]:
# npz = np.load('imagenet21k_ViT-B_16.npz')
npz = np.load('ViT-B_16.npz')

In [10]:
# npz.files

In [11]:
def convert(npz, state_dict):
    new_state_dict = {}
    pytorch_k2v = {jax_to_pytorch(k): v for k, v in npz.items()}
    for pytorch_k, pytorch_v in state_dict.items():
        
        # Naming
        if 'self_attn.out_proj.weight' in pytorch_k:
            v = pytorch_k2v[pytorch_k]
            v = v.reshape(v.shape[0] * v.shape[1], v.shape[2])
        elif 'self_attn.in_proj_' in pytorch_k:
            v = np.stack((pytorch_k2v[pytorch_k + '*q'], 
                          pytorch_k2v[pytorch_k + '*k'], 
                          pytorch_k2v[pytorch_k + '*v']), axis=0)
        else:
            if pytorch_k not in pytorch_k2v:
                print(pytorch_k, list(pytorch_k2v.keys()))
                assert False
            v = pytorch_k2v[pytorch_k]
        v = torch.from_numpy(v)
        
        # Sizing
        if '.weight' in pytorch_k:
            if len(pytorch_v.shape) == 2:
                v = v.transpose(0, 1)
            if len(pytorch_v.shape) == 4:
                v = v.permute(3, 2, 0, 1)
        if ('proj.weight' in pytorch_k):
            v = v.transpose(0, 1)
            v = v.reshape(-1, v.shape[-1]).T
        if ('proj.bias' in pytorch_k):
            print(pytorch_k, v.shape)
        if ('attn.proj_' in pytorch_k and 'weight' in pytorch_k):
            v = v.permute(0, 2, 1)
            v = v.reshape(-1, v.shape[-1])
        if 'attn.proj_' in pytorch_k and 'bias' in pytorch_k:
            v = v.reshape(-1)
        new_state_dict[pytorch_k] = v
    return new_state_dict

In [12]:
def jax_to_pytorch(k):
    k = k.replace('Transformer/encoder_norm', 'norm')
    k = k.replace('LayerNorm_0', 'norm1')
    k = k.replace('LayerNorm_2', 'norm2')
    k = k.replace('MlpBlock_3/Dense_0', 'pwff.fc1')
    k = k.replace('MlpBlock_3/Dense_1', 'pwff.fc2')
    k = k.replace('MultiHeadDotProductAttention_1/out', 'proj')
    k = k.replace('MultiHeadDotProductAttention_1/query', 'attn.proj_q')
    k = k.replace('MultiHeadDotProductAttention_1/key', 'attn.proj_k')
    k = k.replace('MultiHeadDotProductAttention_1/value', 'attn.proj_v')
    k = k.replace('Transformer/posembed_input', 'positional_embedding')
    k = k.replace('encoderblock_', 'blocks.')
    k = 'patch_embedding.bias' if k == 'embedding/bias' else k
    k = 'patch_embedding.weight' if k == 'embedding/kernel' else k
    k = 'class_token' if k == 'cls' else k
    k = k.replace('head', 'fc')
    k = k.replace('kernel', 'weight')
    k = k.replace('scale', 'weight')
    k = k.replace('/', '.')
    k = k.lower()
    return k

In [13]:
new_state_dict = convert(npz, model.state_dict())

transformer.blocks.0.proj.bias torch.Size([768])
transformer.blocks.1.proj.bias torch.Size([768])
transformer.blocks.2.proj.bias torch.Size([768])
transformer.blocks.3.proj.bias torch.Size([768])
transformer.blocks.4.proj.bias torch.Size([768])
transformer.blocks.5.proj.bias torch.Size([768])
transformer.blocks.6.proj.bias torch.Size([768])
transformer.blocks.7.proj.bias torch.Size([768])
transformer.blocks.8.proj.bias torch.Size([768])
transformer.blocks.9.proj.bias torch.Size([768])
transformer.blocks.10.proj.bias torch.Size([768])
transformer.blocks.11.proj.bias torch.Size([768])


In [14]:
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [15]:
import json 

def check(M):
    labels_map = json.load(open('../examples/simple/labels_map.txt'))
    labels_map = [labels_map[str(i)] for i in range(1000)]
    with torch.no_grad():
        outputs = M(img)
    print('-----')
    for idx in torch.topk(outputs, k=5).indices.squeeze(0).tolist():
        prob = torch.softmax(outputs, dim=1)[0, idx].item()
        print('[{idx}] {label:<75} ({p:.2f}%)'.format(idx=idx, label=labels_map[idx], p=prob*100))

In [16]:
model.eval()
check(model)

-----
[388] giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca           (99.51%)
[387] lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens         (0.16%)
[297] sloth bear, Melursus ursinus, Ursus ursinus                                 (0.05%)
[295] American black bear, black bear, Ursus americanus, Euarctos americanus      (0.03%)
[296] ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus                 (0.03%)


In [17]:
def printhook(self, input, output):
    print('Inside ' + self.__class__.__name__ + ' forward')
    print('input: ', type(input))
    print('input[0]: ', type(input[0]))
    print('input size:', input[0].size())
    print('input norm:', input[0].norm())
    if isinstance(output, tuple):
        output = output[0]
    print('output size:', output.data.size())
    print('output norm:', output.data.norm())
    print('-----------\n')

In [18]:
# h1 = m.blocks[0].attn.proj.register_forward_hook(printhook)  # m.blocks[0].register_forward_hook(printhook) 
# h2 = model.transformer.blocks[0].proj.register_forward_hook(printhook)  # model.transformer.layers[0].register_forward_hook(printhook) 
# m(img)
# model(img)
# h1.remove()
# h2.remove()