In [1]:
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
import torch
import torch.nn as nn
from torchinfo import summary

In [2]:
def inflate_conv2d_to_conv3d(conv2d, depth_dim=3):
    # conv2d has weights of shape (out_channels, in_channels, H, W)
    weights_2d = conv2d.weight.data
    out_channels, in_channels, H, W = weights_2d.shape
    
    # Inflate the weights by adding the depth as the last dimension
    # Since medical image has channel = 1, then we will need to change back to greyscale
    # The weights are rearranged to shape (out_channels, in_channels, H, W, depth_dim)
    # Then, normalize by dividing by depth_dim
    weights_2d = weights_2d.sum(1, keepdim = True)
    weights_3d = weights_2d.unsqueeze(-1).repeat(1, 1, 1, 1, depth_dim) / depth_dim
    
    # Create a new 3D convolutional layer
    # Note the adjustment in the kernel and stride sizes to accommodate depth as the last dimension
    conv3d = nn.Conv3d(in_channels, out_channels, (H, W, depth_dim), stride=(16, 16, depth_dim))
    
    # Set the weights from the 2D conv layer to the 3D conv layer
    conv3d.weight.data = weights_3d
    
    return conv3d

class Adapter(nn.Module):
    # Add adaptater block
    # go down and then go up 
    def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True):
        super().__init__()
        self.skip_connect = skip_connect
        D_hidden_features = int(D_features * mlp_ratio)
        self.act = act_layer()
        self.D_fc1 = nn.Linear(D_features, D_hidden_features)
        self.D_fc2 = nn.Linear(D_hidden_features, D_features)
        
    def forward(self, x):
        # x is (BT, HW+1, D)
        xs = self.D_fc1(x)
        xs = self.act(xs)
        xs = self.D_fc2(xs)
        if self.skip_connect:
            x = x + xs
        else:
            x = xs
        return x

class AdapterEncoderLayer(nn.Module):
    def __init__(self, original_layer, D_features):
        super().__init__()
        self.original_layer = original_layer
        self.adapter = Adapter(D_features)
    
    def forward(self, *args, **kwargs):
        # Apply original layer
        output = self.original_layer(*args, **kwargs)
        # Apply adapter
        output = self.adapter(output[0])
        return (output,)

class ViT_inflated(nn.Module):
    def __init__(self, num_class, depth_dim):
        super().__init__()
        model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-base-patch16")
        ViT_embedding = model.vision_model

        conv2d_layer = ViT_embedding.embeddings.patch_embedding
        conv3d_layer = inflate_conv2d_to_conv3d(conv2d_layer, depth_dim)
        ViT_embedding.embeddings.patch_embedding = conv3d_layer

        for i in range(len(ViT_embedding.encoder.layers)):
            ViT_embedding.encoder.layers[i] = AdapterEncoderLayer(ViT_embedding.encoder.layers[i], 768)
        
        self.model = ViT_embedding
        self.proj_head = nn.Linear(in_features=768, out_features=num_class)

        # Freeze all parameters in the model initially
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze the parameters in the adapter layers and projection head
        for layer in self.model.encoder.layers:
            for param in layer.adapter.parameters():
                param.requires_grad = True
        
        for param in self.proj_head.parameters():
            param.requires_grad = True

    def forward(self, image):
        if not self.model: 
            raise Exception('ViT model does not load successfully')
        
        else:
            print('Model Load successfully, inflated weight to 3D')
        
        output = self.model(image)[1]
        output = self.proj_head(output)

        return output 



In [3]:
model = ViT_inflated(2,150)
image = torch.randn(1,1,224,224,150)
print(model(image).shape)

  return self.fget.__get__(instance, owner)()


Model Load successfully, inflated weight to 3D
torch.Size([1, 2])


In [4]:
summary(model=model, 
        input_size=(1, 1, 224, 224,150), # (batch_size, color_channels, height, width)
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Model Load successfully, inflated weight to 3D


Layer (type (var_name))                                                     Input Shape          Output Shape         Param #              Trainable
ViT_inflated (ViT_inflated)                                                 [1, 1, 224, 224, 150] [1, 2]               --                   Partial
├─CLIPVisionTransformer (model)                                             [1, 1, 224, 224, 150] [1, 768]             --                   Partial
│    └─CLIPVisionEmbeddings (embeddings)                                    [1, 1, 224, 224, 150] [1, 197, 768]        768                  False
│    │    └─Conv3d (patch_embedding)                                        [1, 1, 224, 224, 150] [1, 768, 14, 14, 1]  (29,491,968)         False
│    │    └─Embedding (position_embedding)                                  [1, 197]             [1, 197, 768]        (151,296)            False
│    └─LayerNorm (pre_layrnorm)                                             [1, 197, 768]        [1, 197, 768]        