In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
from PIL import Image
import torch
from torchvision import transforms

In [4]:
# Build the model input in one expression: open the example image, resize it
# to 384x384, convert it to a tensor, normalise with Normalize(0.5, 0.5) and
# finally prepend a batch dimension with unsqueeze(0).
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])(Image.open('../examples/simple/img.jpg')).unsqueeze(0)
img.shape

torch.Size([1, 3, 384, 384])

In [5]:
# Explicitly re-import the local vit_pytorch package; the module object returned
# by reload() is what this cell displays.
# NOTE(review): `%autoreload 2` is already enabled earlier in the notebook, so
# this manual reload is presumably redundant - confirm before removing it.
from importlib import reload
import vit_pytorch
reload(vit_pytorch)

<module 'vit_pytorch' from '/home/luke/projects/experiments/ViT-PyTorch/vit_pytorch/__init__.py'>

In [6]:
# Instantiate the 'B_16_imagenet1k' ViT variant with 1000 output classes and
# pretrained=False; its weights are filled in later from the converted checkpoint.
# The commented-out line is the alternative 'B_16' configuration (21843 classes).
# model = vit_pytorch.ViT(name='B_16', pretrained=False, num_classes=21843)
model = vit_pytorch.ViT(name='B_16_imagenet1k', pretrained=False, num_classes=1000)

In [7]:
# list(model.state_dict().keys())

### Convert the JAX checkpoint to PyTorch weights

In [8]:
import numpy as np

In [9]:
# Load the checkpoint archive whose arrays are converted below. The path is
# relative to the notebook's working directory, so the file must be present there.
# The commented-out line loads the imagenet21k archive instead.
# npz = np.load('imagenet21k_ViT-B_16.npz')
npz = np.load('ViT-B_16.npz')

In [10]:
# npz.files

In [114]:
def convert(npz, state_dict):
    """Map checkpoint arrays onto the keys and layouts of ``state_dict``.

    The names in ``npz`` are first rewritten with ``jax_to_pytorch``. Then, for
    each key of ``state_dict``, the matching array is fetched, wrapped with
    ``torch.from_numpy`` and re-laid-out by rules that are selected purely from
    substrings of the key (plus the rank of the target tensor).

    Args:
        npz: Mapping whose ``.items()`` yields ``(name, numpy array)`` pairs;
            in this notebook it is the object returned by ``np.load``.
        state_dict: Mapping of target key -> tensor. Only the keys and
            ``len(value.shape)`` are read.

    Returns:
        dict: one converted tensor per key of ``state_dict``, in iteration order.

    Raises:
        AssertionError: if a key handled by the generic branch has no renamed
            counterpart in ``npz``; the key and all available names are
            printed before the assertion fires.

    Side effects:
        Prints the key and tensor shape of every entry whose key contains
        ``'proj.bias'``.
    """
    renamed_arrays = {jax_to_pytorch(name): array for name, array in npz.items()}
    converted = {}

    for key, target in state_dict.items():
        # ---- Naming: choose the source array(s) for this parameter ----
        # NOTE(review): the two `self_attn` branches match the key names used by
        # the commented-out `jax_to_pytorch` variant in this notebook; they are
        # presumably not exercised by the current model - confirm before removing.
        if 'self_attn.out_proj.weight' in key:
            array = renamed_arrays[key]
            # Merge the first two axes of the source array into one.
            array = array.reshape(array.shape[0] * array.shape[1], array.shape[2])
        elif 'self_attn.in_proj_' in key:
            # Stack the separate query/key/value arrays along a new leading axis.
            array = np.stack(
                [renamed_arrays[key + suffix] for suffix in ('*q', '*k', '*v')],
                axis=0,
            )
        else:
            if key not in renamed_arrays:
                print(key, list(renamed_arrays.keys()))
                assert False
            array = renamed_arrays[key]
        tensor = torch.from_numpy(array)

        # ---- Sizing: the rules below are applied in order and may compound ----
        if '.weight' in key:
            target_rank = len(target.shape)
            if target_rank == 2:
                tensor = tensor.transpose(0, 1)
            if target_rank == 4:
                # Move the last two source axes to the front, reversed.
                # Presumably HWIO -> OIHW for a convolution kernel - TODO confirm.
                tensor = tensor.permute(3, 2, 0, 1)

        if 'proj.weight' in key:
            # For a 2-D target the '.weight' rule above already swapped axes 0
            # and 1; swapping again restores the source order before the
            # leading axes are flattened and the result is transposed.
            tensor = tensor.transpose(0, 1)
            tensor = tensor.reshape(-1, tensor.shape[-1]).T

        if 'proj.bias' in key:
            print(key, tensor.shape)

        is_attn_proj = 'attn.proj_' in key
        if is_attn_proj and 'weight' in key:
            # Swap the last two axes, then flatten everything but the new last
            # axis. Assumes a 3-D tensor at this point - TODO confirm against
            # the checkpoint layout.
            tensor = tensor.permute(0, 2, 1)
            tensor = tensor.reshape(-1, tensor.shape[-1])
        if is_attn_proj and 'bias' in key:
            tensor = tensor.reshape(-1)

        converted[key] = tensor

    return converted

In [115]:
def jax_to_pytorch(k):
    """Translate a checkpoint parameter name into the PyTorch-style key.

    The name is rewritten in three ordered stages: substring replacements for
    the transformer sub-modules, whole-name replacements for the embedding and
    class-token entries, and finally generic replacements ('head' -> 'fc',
    'kernel'/'scale' -> 'weight', '/' -> '.') followed by lower-casing.

    Args:
        k: Parameter name as stored in the checkpoint, with '/' separators.

    Returns:
        str: the rewritten, lower-cased, dot-separated key.
    """
    # Stage 1: substring rewrites, applied in this exact order.
    module_rewrites = (
        ('Transformer/encoder_norm', 'norm'),
        ('LayerNorm_0', 'norm1'),
        ('LayerNorm_2', 'norm2'),
        ('MlpBlock_3/Dense_0', 'pwff.fc1'),
        ('MlpBlock_3/Dense_1', 'pwff.fc2'),
        ('MultiHeadDotProductAttention_1/out', 'proj'),
        ('MultiHeadDotProductAttention_1/query', 'attn.proj_q'),
        ('MultiHeadDotProductAttention_1/key', 'attn.proj_k'),
        ('MultiHeadDotProductAttention_1/value', 'attn.proj_v'),
        ('Transformer/posembed_input', 'positional_embedding'),
        ('encoderblock_', 'blocks.'),
    )
    # Stage 2: names that are replaced only when they match in full.
    whole_name_rewrites = {
        'embedding/bias': 'patch_embedding.bias',
        'embedding/kernel': 'patch_embedding.weight',
        'cls': 'class_token',
    }
    # Stage 3: generic substring rewrites applied to whatever is left.
    generic_rewrites = (
        ('head', 'fc'),
        ('kernel', 'weight'),
        ('scale', 'weight'),
        ('/', '.'),
    )

    name = k
    for old, new in module_rewrites:
        name = name.replace(old, new)
    name = whole_name_rewrites.get(name, name)
    for old, new in generic_rewrites:
        name = name.replace(old, new)
    return name.lower()

In [116]:
# Convert the checkpoint arrays to the keys of `model`'s state dict; `convert`
# prints the key and shape of every `proj.bias` entry as it runs.
new_state_dict = convert(npz, model.state_dict())

transformer.blocks.0.proj.bias torch.Size([768])
transformer.blocks.1.proj.bias torch.Size([768])
transformer.blocks.2.proj.bias torch.Size([768])
transformer.blocks.3.proj.bias torch.Size([768])
transformer.blocks.4.proj.bias torch.Size([768])
transformer.blocks.5.proj.bias torch.Size([768])
transformer.blocks.6.proj.bias torch.Size([768])
transformer.blocks.7.proj.bias torch.Size([768])
transformer.blocks.8.proj.bias torch.Size([768])
transformer.blocks.9.proj.bias torch.Size([768])
transformer.blocks.10.proj.bias torch.Size([768])
transformer.blocks.11.proj.bias torch.Size([768])


In [117]:
# Load the converted weights into the model; the value returned by
# load_state_dict is displayed as this cell's output.
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [118]:
import json 

def check(M):
    """Print the five highest-scoring labels that ``M`` predicts for ``img``.

    ``M`` is called on the notebook-level tensor ``img`` with gradient tracking
    disabled. A ``-----`` separator is printed, followed by one line per
    prediction: the label left-aligned in 75 columns and its softmax
    probability as a percentage.

    Args:
        M: Callable applied to ``img``. Its result is passed to ``torch.topk``
            and ``torch.softmax(dim=1)``, so it is assumed to be a
            ``(1, num_classes)`` score tensor - TODO confirm for other models.

    Notes:
        Labels are read from ``../examples/simple/labels_map.txt``, a JSON
        object indexed here by the string forms of ``0`` .. ``999``.
    """
    # Read the label file inside a context manager so the handle is closed
    # deterministically instead of being left for the garbage collector.
    with open('../examples/simple/labels_map.txt') as labels_file:
        labels_map = json.load(labels_file)
    labels_map = [labels_map[str(i)] for i in range(1000)]
    with torch.no_grad():
        outputs = M(img)
    print('-----')
    top_indices = torch.topk(outputs, k=5).indices.squeeze(0).tolist()
    # The softmax does not depend on the loop variable; compute it once.
    probs = torch.softmax(outputs, dim=1)
    for idx in top_indices:
        prob = probs[0, idx].item()
        print('{label:<75} ({p:.2f}%)'.format(label=labels_map[idx], p=prob*100))

In [119]:
# Reference model: timm's pretrained 'vit_base_patch16_384', switched to eval
# mode and run through `check` to obtain the baseline predictions.
import timm
m = timm.create_model('vit_base_patch16_384', pretrained=True)
m.eval()
check(m)

-----
giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca           (99.51%)
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens         (0.16%)
sloth bear, Melursus ursinus, Ursus ursinus                                 (0.05%)
American black bear, black bear, Ursus americanus, Euarctos americanus      (0.03%)
ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus                 (0.03%)


In [120]:
# Run the same check on the converted model in eval mode, for comparison with
# the timm reference predictions printed by the previous cell.
model.eval()
check(model)

-----
giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca           (99.51%)
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens         (0.16%)
sloth bear, Melursus ursinus, Ursus ursinus                                 (0.05%)
American black bear, black bear, Ursus americanus, Euarctos americanus      (0.03%)
ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus                 (0.03%)


In [121]:
def printhook(self, input, output):
    """Forward-hook callback that prints a short summary of one module call.

    Printed, in order: the module's class name, the type of ``input`` and of
    its first element, the size and norm of that first element, the size and
    norm of the output's ``.data``, and a separator line.

    Args:
        self: The module the hook is attached to; only its class name is used.
        input: Indexable container of the module's inputs; only ``input[0]``
            is inspected.
        output: The module's output. If it is a tuple, only its first element
            is summarised.
    """
    print('Inside ' + self.__class__.__name__ + ' forward')
    print('input: ', type(input))

    first_input = input[0]
    print('input[0]: ', type(first_input))
    print('input size:', first_input.size())
    print('input norm:', first_input.norm())

    # Modules may return a tuple; summarise just its leading element.
    primary = output[0] if isinstance(output, tuple) else output
    detached = primary.data
    print('output size:', detached.size())
    print('output norm:', detached.norm())
    print('-----------\n')

In [122]:
# Display the `attn.proj` weight matrix of block 0 of the timm reference model.
# The commented-out lines are alternative comparisons against the converted
# model's `transformer.blocks[0].proj` weights (difference, equality, norms).
m.blocks[0].attn.proj.weight.data # [3,7].item()
# m.blocks[0].attn.proj.weight.data - model.transformer.blocks[0].proj.weight.data
# torch.all(m.blocks[0].attn.proj.weight.data == model.transformer.blocks[0].proj.weight.data)
# m.blocks[0].attn.proj.weight.data.norm(), model.transformer.blocks[0].proj.weight.data.norm()

tensor([[-0.0278,  0.0162,  0.0049,  ..., -0.0356,  0.0098, -0.0595],
        [-0.0901, -0.0046,  0.0499,  ..., -0.0351, -0.0044,  0.0291],
        [ 0.0160, -0.0356, -0.0768,  ...,  0.0045, -0.0360,  0.0005],
        ...,
        [ 0.0552,  0.0912,  0.0156,  ...,  0.0040,  0.0362,  0.0054],
        [-0.0435, -0.1232,  0.0398,  ..., -0.0223, -0.0014,  0.0225],
        [ 0.0034,  0.0090, -0.0285,  ...,  0.0361, -0.0960,  0.0245]])

In [123]:
# Display the converted model's `transformer.blocks[0].proj` weights for a
# visual comparison with the timm `attn.proj` weights shown by the previous cell.
model.transformer.blocks[0].proj.weight.data

tensor([[-0.0278,  0.0162,  0.0049,  ..., -0.0356,  0.0098, -0.0595],
        [-0.0901, -0.0046,  0.0499,  ..., -0.0351, -0.0044,  0.0291],
        [ 0.0160, -0.0356, -0.0768,  ...,  0.0045, -0.0360,  0.0005],
        ...,
        [ 0.0552,  0.0912,  0.0156,  ...,  0.0040,  0.0362,  0.0054],
        [-0.0435, -0.1232,  0.0398,  ..., -0.0223, -0.0014,  0.0225],
        [ 0.0034,  0.0090, -0.0285,  ...,  0.0361, -0.0960,  0.0245]])

In [124]:
# Attach `printhook` to the matching projection layer of both models, run one
# forward pass through each, and detach the hooks again.
# The try/finally blocks guarantee the hooks are removed even if registration
# of the second hook or either forward pass raises; otherwise a stale hook
# would keep printing on every later call of that model.
h1 = m.blocks[0].attn.proj.register_forward_hook(printhook)  # whole block instead: m.blocks[0].register_forward_hook(printhook)
try:
    h2 = model.transformer.blocks[0].proj.register_forward_hook(printhook)  # older layout: model.transformer.layers[0].register_forward_hook(printhook)
    try:
        m(img)
        model(img)
    finally:
        h2.remove()
finally:
    h1.remove()

Inside Linear forward
input:  <class 'tuple'>
input[0]:  <class 'torch.Tensor'>
input size: torch.Size([1, 577, 768])
input norm: tensor(192.5844, grad_fn=<NormBackward1>)
output size: torch.Size([1, 577, 768])
output norm: tensor(284.1073)
-----------

Inside Linear forward
input:  <class 'tuple'>
input[0]:  <class 'torch.Tensor'>
input size: torch.Size([1, 577, 768])
input norm: tensor(192.5844, grad_fn=<NormBackward1>)
output size: torch.Size([1, 577, 768])
output norm: tensor(284.1073)
-----------



In [68]:
# torch.all(model.positional_embedding.pos_embedding.data == m.pos_embed.data)
# torch.all(m.patch_embed.proj.weight.data == model.patch_embedding.weight.data)
# torch.all(model.class_token == m.cls_token)
# torch.all(m.patch_embed.weight.data == model.transformer.layers[3].self_attn.in_proj_weight.data)
# torch.all(m.blocks[3].attn.qkv.weight.data == model.transformer.layers[3].self_attn.in_proj_weight.data)
# torch.all(m.blocks[11].attn.qkv.bias.data == model.transformer.layers[11].self_attn.in_proj_bias.data)
# torch.all(m.blocks[11].attn.proj.weight.data == model.transformer.layers[11].self_attn.out_proj.weight.data)
# torch.all(m.blocks[11].attn.qkv.weight.data == model.transformer.layers[11].self_attn.in_proj_weight.data)
# torch.all(m.blocks[11].mlp.fc1.weight.data == model.transformer.layers[11].linear1.weight.data)
# torch.all(m.blocks[11].norm2.weight.data == model.transformer.layers[11].norm2.weight.data)
# torch.all(m.blocks[11].norm1.weight.data == model.transformer.layers[11].norm1.weight.data)
# torch.all(m.norm.weight.data == model.norm.weight.data)
# torch.all(m.head.weight.data == model.fc.weight.data)

In [None]:
# def jax_to_pytorch(k):
#     k = k.replace('Transformer/encoder_norm', 'norm')
#     k = k.replace('LayerNorm_0', 'norm1')
#     k = k.replace('LayerNorm_2', 'norm2')
#     k = k.replace('MlpBlock_3/Dense_0', 'linear1')
#     k = k.replace('MlpBlock_3/Dense_1', 'linear2')
#     k = k.replace('MultiHeadDotProductAttention_1/out', 'self_attn.out_proj')
#     k = k.replace('MultiHeadDotProductAttention_1/query/kernel', 'self_attn.in_proj_weight*q')
#     k = k.replace('MultiHeadDotProductAttention_1/key/kernel', 'self_attn.in_proj_weight*k')
#     k = k.replace('MultiHeadDotProductAttention_1/value/kernel', 'self_attn.in_proj_weight*v')
#     k = k.replace('MultiHeadDotProductAttention_1/query/bias', 'self_attn.in_proj_bias*q')
#     k = k.replace('MultiHeadDotProductAttention_1/key/bias', 'self_attn.in_proj_bias*k')
#     k = k.replace('MultiHeadDotProductAttention_1/value/bias', 'self_attn.in_proj_bias*v')
#     k = k.replace('Transformer/posembed_input', 'positional_embedding')
#     k = k.replace('encoderblock_', 'layers.')
#     k = 'patch_embedding.bias' if k == 'embedding/bias' else k
#     k = 'patch_embedding.weight' if k == 'embedding/kernel' else k
#     k = 'class_token' if k == 'cls' else k
#     k = k.replace('head', 'fc')
#     k = k.replace('kernel', 'weight')
#     k = k.replace('scale', 'weight')
#     k = k.replace('/', '.')
#     k = k.lower()
#     return k