In [1]:
from mae.models_vit import vit_base_patch16
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
import copy
import math

In [2]:
# Load the MAE pre-training checkpoint onto the CPU (no GPU required).
# torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(
    '/home/lili/code/ssl/ssl-medical-sattelite/mae/mae_baseline_medical/checkpoint-49.pth',
    map_location=torch.device('cpu'),
)

In [3]:
# Weights handed to load_state_dict / load_from_mae in the following cells.
checkpoint_model = checkpoint["model"]

In [4]:
model_mae = vit_base_patch16()
# strict=False tolerates key mismatches between the checkpoint and the model.
# Report the keys the model expected but did not receive, so a weight that
# silently failed to load is visible before the parity checks below.
msg = model_mae.load_state_dict(checkpoint_model, strict=False)
print('missing keys:', msg.missing_keys)

In [5]:
import ml_collections
def get_b16_config(
    pretrained_path='/home/lili/code/ssl/ssl-medical-sattelite/mae/mae_baseline_medical/checkpoint-49.pth',
    n_classes=2,
):
    """Return the ViT-B/16 configuration as an ``ml_collections.ConfigDict``.

    Encoder settings: 16x16 patches, hidden size 768, 12 layers with 12 heads
    each, MLP dim 3072, dropout 0.1 and attention dropout 0.0. Decoder-related
    fields (``decoder_channels``, ``n_classes``, ``activation``) are stored as
    well; the encoder-only model in this notebook does not read them.

    Args:
        pretrained_path: stored as ``config.pretrained_path``. Defaults to the
            MAE checkpoint this notebook was written against; pass a different
            path instead of editing the function on another machine. (An older
            setup used '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'.)
        n_classes: stored as ``config.n_classes``.

    Returns:
        The populated ConfigDict.
    """
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1

    config.classifier = 'seg'
    config.representation_size = None
    config.resnet_pretrained_path = None
    config.pretrained_path = pretrained_path
    config.patch_size = 16

    config.decoder_channels = (256, 128, 64, 16)
    config.n_classes = n_classes
    config.activation = 'softmax'
    return config

In [6]:
class VisionTransformer(nn.Module):
    """Encoder-only ViT that can be initialised from an MAE checkpoint.

    Only the transformer encoder (``self.transformer``) is built. The
    segmentation decoder and head are not constructed, so ``forward`` raises
    ``NotImplementedError`` until they are added; call ``self.transformer(x)``
    to obtain encoder features.

    Args:
        config: model config (see ``get_b16_config``). ``config.classifier`` is
            stored on the module and the whole config is passed to ``Transformer``.
        img_size: input image size forwarded to ``Transformer``.
        num_classes: stored as ``self.num_classes``; unused by the encoder.
        zero_head: stored as ``self.zero_head``; unused by the encoder.
        vis: forwarded to ``Transformer``.
    """

    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        # TODO: build the decoder (DecoderCup) and segmentation head
        # (SegmentationHead) from config.decoder_channels / config.n_classes.
        self.config = config

    def forward(self, x):
        """Encode ``x`` and run the segmentation decoder and head.

        Single-channel inputs are repeated to three channels first.

        Raises:
            NotImplementedError: if ``decoder`` / ``segmentation_head`` have not
                been attached to the module (they are not built in ``__init__``).
        """
        if not (hasattr(self, "decoder") and hasattr(self, "segmentation_head")):
            raise NotImplementedError(
                "VisionTransformer.forward needs self.decoder and "
                "self.segmentation_head, which are not built yet; call "
                "self.transformer(x) for encoder features."
            )
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        return self.segmentation_head(x)

    def load_from_mae(self, weights):
        """Copy encoder weights from an MAE ``state_dict`` into this model.

        Args:
            weights: mapping keyed by MAE parameter names. Reads
                ``patch_embed.proj.{weight,bias}``, ``norm.{weight,bias}`` and
                ``pos_embed`` here, and the ``blocks.{i}.*`` entries of every
                encoder layer through each block's own ``load_from_mae``.
        """
        embeddings = self.transformer.embeddings
        encoder = self.transformer.encoder
        with torch.no_grad():
            embeddings.patch_embeddings.weight.copy_(weights["patch_embed.proj.weight"])
            embeddings.patch_embeddings.bias.copy_(weights["patch_embed.proj.bias"])

            encoder.encoder_norm.weight.copy_(weights["norm.weight"])
            # Previously copied "norm.weight" into the bias by mistake.
            encoder.encoder_norm.bias.copy_(weights["norm.bias"])

            self._load_position_embeddings(weights["pos_embed"])

            # Layer i of the ModuleList receives the checkpoint's blocks.{i}.* tensors.
            for idx, block in enumerate(encoder.layer):
                block.load_from_mae(weights, n_block=str(idx))

            if embeddings.hybrid:
                # ResNet stem/body weights. NOTE(review): these '/'-separated
                # keys differ from the MAE naming above and are presumably
                # absent from MAE checkpoints; only hybrid configs reach this.
                hybrid = embeddings.hybrid_model
                hybrid.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                hybrid.root.gn.weight.copy_(np2th(weights["gn_root/scale"]).view(-1))
                hybrid.root.gn.bias.copy_(np2th(weights["gn_root/bias"]).view(-1))
                for bname, block in hybrid.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)

    def _load_position_embeddings(self, posemb):
        """Copy MAE positional embeddings into ``embeddings.position_embeddings``.

        Three cases are handled:
        - identical shape: copied as-is;
        - exactly one extra leading token in the checkpoint (presumably the
          class token): that token is dropped;
        - otherwise the square token grid is resampled to the target grid with
          ``scipy.ndimage.zoom`` (order=1, i.e. linear interpolation), after
          dropping a leading extra token if the count is not a perfect square.
        """
        target = self.transformer.embeddings.position_embeddings
        posemb = torch.as_tensor(posemb)
        with torch.no_grad():
            if posemb.size() == target.size():
                target.copy_(posemb)
                return
            if posemb.size(1) - 1 == target.size(1):
                target.copy_(posemb[:, 1:])
                return

            from scipy import ndimage  # only needed for the resize path

            print("load_pretrained: resized variant: %s to %s" % (posemb.size(), target.size()))
            grid = posemb[0]
            gs_old = int(round(np.sqrt(grid.size(0))))
            if gs_old * gs_old != grid.size(0):
                # Not a square token grid: drop the leading extra token first.
                grid = grid[1:]
                gs_old = int(round(np.sqrt(grid.size(0))))
            gs_new = int(np.sqrt(target.size(1)))
            print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
            grid = grid.detach().cpu().numpy().reshape(gs_old, gs_old, -1)
            zoom = (gs_new / gs_old, gs_new / gs_old, 1)
            grid = ndimage.zoom(grid, zoom, order=1)
            target.copy_(torch.from_numpy(grid.reshape(1, gs_new * gs_new, -1)))
                        

def np2th(weights, conv=False):
    """Wrap a NumPy array as a torch tensor via ``torch.from_numpy``.

    With ``conv=True`` the array is treated as an HWIO convolution kernel and
    permuted to PyTorch's OIHW layout before wrapping.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)


def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


# Activation functions looked up by name.
ACT2FN = {
    "gelu": torch.nn.functional.gelu,
    "relu": torch.nn.functional.relu,
    "swish": swish,
}


In [7]:
class Attention(nn.Module):
    """Multi-head self-attention with separate query/key/value projections.

    Args:
        config: reads ``hidden_size``, ``transformer["num_heads"]`` and
            ``transformer["attention_dropout_rate"]``.
        vis: when True, ``forward`` also returns the attention probabilities
            (taken before attention dropout); otherwise it returns None for them.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        heads = config.transformer["num_heads"]
        width = config.hidden_size
        self.num_attention_heads = heads
        self.attention_head_size = int(width / heads)
        self.all_head_size = heads * self.attention_head_size

        # Construction order is kept so parameter init draws from the RNG identically.
        self.query = Linear(width, self.all_head_size)
        self.key = Linear(width, self.all_head_size)
        self.value = Linear(width, self.all_head_size)
        self.out = Linear(width, width)

        p_attn = config.transformer["attention_dropout_rate"]
        self.attn_dropout = Dropout(p_attn)
        # The output-projection dropout also uses the attention dropout rate.
        self.proj_dropout = Dropout(p_attn)

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, S, H*D) -> (B, H, S, D)."""
        *lead, _ = x.size()
        x = x.view(*lead, self.num_attention_heads, self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Return ``(attention_output, weights)``; ``weights`` is None unless ``vis``."""
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        probs = self.softmax(scores)
        weights = probs if self.vis else None
        probs = self.attn_dropout(probs)

        # (B, H, S, D) -> (B, S, H, D) -> (B, S, H*D)
        context = (probs @ v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)
        return self.proj_dropout(self.out(context)), weights


class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> GELU -> dropout -> fc2 -> dropout.

    Args:
        config: reads ``hidden_size``, ``transformer["mlp_dim"]`` and
            ``transformer["dropout_rate"]``.
    """

    def __init__(self, config):
        super(Mlp, self).__init__()
        hidden, inner = config.hidden_size, config.transformer["mlp_dim"]
        self.fc1 = Linear(hidden, inner)
        self.fc2 = Linear(inner, hidden)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, then near-zero normal (std 1e-6) biases."""
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)
        for layer in (self.fc1, self.fc2):
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        """Apply the feed-forward block; the same dropout follows both layers."""
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
    
class Embeddings(nn.Module):
    """Patch embedding (strided Conv2d) plus learned position embeddings.

    Without ``config.patches["grid"]`` the image is cut into non-overlapping
    ``config.patches["size"]`` patches. With a grid, a ResNetV2 backbone runs
    first (NOTE(review): ResNetV2 is not defined in this notebook, so that
    path would raise NameError).

    Args:
        config: model config (patches, hidden_size, transformer dropout rate).
        img_size: int or (H, W) input size; fixes the number of position embeddings.
        in_channels: input channels of the patch projection (non-hybrid path).
    """

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.config = config
        height, width = _pair(img_size)

        grid = config.patches.get("grid")
        self.hybrid = grid is not None
        if self.hybrid:   # ResNet backbone in front of the patch projection
            patch_size = (height // 16 // grid[0], width // 16 // grid[1])
            stride_px = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (height // stride_px[0]) * (width // stride_px[1])
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (height // patch_size[0]) * (width // patch_size[1])

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Return ``(embeddings, features)``; ``features`` is None unless hybrid."""
        features = None
        if self.hybrid:
            x, features = self.hybrid_model(x)
        # (B, hidden, H/p, W/p) -> (B, hidden, n_patches) -> (B, n_patches, hidden)
        tokens = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        return self.dropout(tokens + self.position_embeddings), features


class Block(nn.Module):
    """Pre-norm transformer layer.

    LayerNorm -> Attention -> residual add, then LayerNorm -> Mlp -> residual add.

    Args:
        config: reads ``hidden_size``; also passed to ``Mlp`` and ``Attention``.
        vis: passed to ``Attention``.
    """

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        """Return ``(x, weights)`` where ``weights`` is whatever ``self.attn`` returns."""
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from_mae(self, weights, n_block):
        """Copy one MAE encoder block's tensors into this block.

        Args:
            weights: MAE state_dict-style mapping. For prefix ``blocks.{n_block}``
                it reads ``norm1``, ``attn.qkv``, ``attn.proj``, ``norm2``,
                ``mlp.fc1`` and ``mlp.fc2`` (``.weight`` and ``.bias`` of each).
            n_block: block index (int or str) used to build the key prefix.

        The fused ``attn.qkv`` tensors are split along dim 0 into consecutive
        ``hidden_size`` slices for query, key and value.

        Raises:
            AssertionError: if any source tensor's shape differs from its target
                parameter. All shapes are checked before anything is copied.
        """
        prefix = f"blocks.{n_block}"
        h = self.hidden_size
        qkv_weight = weights[f"{prefix}.attn.qkv.weight"]
        qkv_bias = weights[f"{prefix}.attn.qkv.bias"]

        # (target parameter, source tensor, label used in error messages)
        copies = [
            (self.attn.query.weight, qkv_weight[:h], "attn.qkv.weight[query]"),
            (self.attn.key.weight, qkv_weight[h:2 * h], "attn.qkv.weight[key]"),
            (self.attn.value.weight, qkv_weight[2 * h:], "attn.qkv.weight[value]"),
            (self.attn.query.bias, qkv_bias[:h], "attn.qkv.bias[query]"),
            (self.attn.key.bias, qkv_bias[h:2 * h], "attn.qkv.bias[key]"),
            (self.attn.value.bias, qkv_bias[2 * h:], "attn.qkv.bias[value]"),
            (self.attn.out.weight, weights[f"{prefix}.attn.proj.weight"], "attn.proj.weight"),
            (self.attn.out.bias, weights[f"{prefix}.attn.proj.bias"], "attn.proj.bias"),
            (self.ffn.fc1.weight, weights[f"{prefix}.mlp.fc1.weight"], "mlp.fc1.weight"),
            (self.ffn.fc2.weight, weights[f"{prefix}.mlp.fc2.weight"], "mlp.fc2.weight"),
            (self.ffn.fc1.bias, weights[f"{prefix}.mlp.fc1.bias"], "mlp.fc1.bias"),
            (self.ffn.fc2.bias, weights[f"{prefix}.mlp.fc2.bias"], "mlp.fc2.bias"),
            (self.attention_norm.weight, weights[f"{prefix}.norm1.weight"], "norm1.weight"),
            (self.attention_norm.bias, weights[f"{prefix}.norm1.bias"], "norm1.bias"),
            (self.ffn_norm.weight, weights[f"{prefix}.norm2.weight"], "norm2.weight"),
            (self.ffn_norm.bias, weights[f"{prefix}.norm2.bias"], "norm2.bias"),
        ]

        # Validate first: copy_ would otherwise silently broadcast a wrongly
        # shaped tensor (e.g. size 1) and a mismatch could leave the block half-loaded.
        for param, value, label in copies:
            assert param.shape == value.shape, (
                f"{prefix}.{label}: expected {tuple(param.shape)}, got {tuple(value.shape)}"
            )

        with torch.no_grad():
            for param, value, _ in copies:
                param.copy_(value)


class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
    
            if self.vis:
                attn_weights.append(weights)
        # encoded = self.encoder_norm(hidden_states)
        return hidden_states, attn_weights
        return encoded, attn_weights

class Transformer(nn.Module):
    """Embeddings followed by the transformer encoder.

    ``forward`` returns ``(encoded, attn_weights, features)``: the encoder
    output, the encoder's list of attention weights, and the features returned
    by the embeddings (None for the non-hybrid patch embedding).
    """

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        tokens, features = self.embeddings(input_ids)
        # encoder yields (encoded (B, n_patch, hidden), attn_weights)
        return (*self.encoder(tokens), features)
   

In [8]:
# MAE checkpoint keys for encoder block i (each has .weight and .bias):
# blocks.{i}.norm1, blocks.{i}.attn.qkv, blocks.{i}.attn.proj,
# blocks.{i}.norm2, blocks.{i}.mlp.fc1, blocks.{i}.mlp.fc2

In [9]:
# Build the ViT-B/16 encoder and copy the MAE checkpoint weights into it.
net = VisionTransformer(config=get_b16_config(), img_size=224, num_classes=14)
net.load_from_mae(weights=checkpoint_model)

layer 0
layer 1
layer 2
layer 3
layer 4
layer 5
layer 6
layer 7
layer 8
layer 9
layer 10
layer 11


In [10]:
# Inference mode for both models (disables dropout) before comparing outputs.
model_mae.train(False)
net.train(False)

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0): Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (out): Linear(in_features=768, out_features=768, bias=True)
     

In [11]:
np.random.seed(12)  # fixed seed so the comparison input is reproducible
input_1 = torch.tensor(np.random.normal(size=(1, 3, 224, 224)), dtype=torch.float32)

In [12]:
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    res_mae = model_mae.forward_features_transfer(input_1)

In [13]:
# Inference only: no_grad avoids building an autograd graph; [0] is the encoder output.
with torch.no_grad():
    res_copy = net.transformer(input_1)[0]

In [16]:
from numpy.testing import assert_almost_equal

In [17]:
# Both outputs are float32 from two equivalent implementations; compare with
# relative+absolute tolerances instead of a 7-decimal absolute bound, which is
# finer than float32 resolution for values of this magnitude.
np.testing.assert_allclose(res_mae.detach().numpy(), res_copy.detach().numpy(), rtol=1e-5, atol=1e-5)

In [18]:
 res_copy.detach().cpu().numpy()  # encoder output of the copied model

array([[[ 2.28922   ,  0.8574274 , -1.8156488 , ..., -2.8316395 ,
         -1.2804598 ,  0.3285563 ],
        [ 1.9431055 ,  5.518027  , -2.3719282 , ..., -4.5553746 ,
         -1.151378  ,  1.2527792 ],
        [ 3.1565354 ,  2.9455905 , -0.35940456, ..., -0.61962926,
          0.06183657,  2.8153422 ],
        ...,
        [-2.5021086 , -0.01082426, -6.700266  , ..., -1.0508425 ,
          3.5431    , -3.3149319 ],
        [-1.0556322 ,  1.2118303 , -8.897961  , ..., -2.5832372 ,
          4.1201477 , -1.0326865 ],
        [ 4.4617686 ,  2.1963015 , -5.563488  , ..., -5.1784725 ,
          0.46552312, -1.8062646 ]]], dtype=float32)

In [19]:
res_mae.detach().cpu().numpy()  # encoder output of the reference MAE model

array([[[ 2.28922   ,  0.8574274 , -1.8156488 , ..., -2.8316395 ,
         -1.2804598 ,  0.3285563 ],
        [ 1.9431055 ,  5.518027  , -2.3719282 , ..., -4.5553746 ,
         -1.151378  ,  1.2527792 ],
        [ 3.1565354 ,  2.9455905 , -0.35940456, ..., -0.61962926,
          0.06183657,  2.8153422 ],
        ...,
        [-2.5021086 , -0.01082426, -6.700266  , ..., -1.0508425 ,
          3.5431    , -3.3149319 ],
        [-1.0556322 ,  1.2118303 , -8.897961  , ..., -2.5832372 ,
          4.1201477 , -1.0326865 ],
        [ 4.4617686 ,  2.1963015 , -5.563488  , ..., -5.1784725 ,
          0.46552312, -1.8062646 ]]], dtype=float32)

In [20]:
# Use a dedicated seeded generator so this input does not depend on how many
# times earlier cells drew from the global NumPy RNG.
input_2 = torch.tensor(np.random.default_rng(13).normal(size=(1, 3, 224, 224)), dtype=torch.float32)
with torch.no_grad():
    res_mae_2 = model_mae.forward_features_transfer(input_2)
    res_copy_2 = net.transformer(input_2)[0]
np.testing.assert_allclose(res_mae_2.numpy(), res_copy_2.numpy(), rtol=1e-5, atol=1e-5)