In [1]:
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter

class Adapter_ViT(nn.Module):
    """Replace a frozen vision transformer's head with a trainable MLP adapter.

    Every parameter of ``vit_model`` is frozen (``requires_grad = False``) before
    the adapter is built, so the adapter is the only trainable part.

    The adapter is ``num_layers`` pairs of ``Linear(dim, dim) + ReLU``, followed by
    a final ``Linear(dim, num_classes)``. Here ``dim`` is the input width of the
    first block's ``attn.qkv`` projection. The adapter is installed as
    ``vit_model.head``, so the backbone's own forward pass returns adapter outputs.

    Args:
        vit_model: a vision transformer exposing ``blocks[i].attn.qkv`` (an
            ``nn.Linear``) whose forward pass routes through ``head``, e.g. a
            timm ``VisionTransformer``. It is modified in place.
        num_classes: number of output classes. Must be a positive integer.
        num_layers: number of hidden ``Linear + ReLU`` layers. Defaults to
            ``len(vit_model.blocks) // 2``.

    Raises:
        ValueError: if ``num_classes`` is not positive or ``num_layers`` is negative.

    Examples::
        >>> model = timm.create_model("vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=True)
        >>> adapter_model = Adapter_ViT(model, num_classes=1000)
        >>> preds = adapter_model(img)
        >>> print(preds.shape)
        torch.Size([1, 1000])
    """

    def __init__(self,
                 vit_model: nn.Module,
                 num_classes: int = 0,
                 num_layers: Optional[int] = None):
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would then silently build a zero-width output layer.
        if num_classes <= 0:
            raise ValueError(f"num_classes must be a positive integer, got {num_classes}")
        if num_layers is None:
            num_layers = len(vit_model.blocks) // 2
        if num_layers < 0:
            raise ValueError(f"num_layers must be non-negative, got {num_layers}")

        # Freeze the backbone *before* creating the adapter, so only the adapter trains.
        for param in vit_model.parameters():
            param.requires_grad = False

        self.dim = vit_model.blocks[0].attn.qkv.in_features
        self.adapter = nn.Sequential()
        # Submodule names ("layer_i", "relu_i", "fc") are kept stable so existing
        # checkpoints still load.
        for t_layer_i in range(num_layers):
            self.adapter.add_module(f"layer_{t_layer_i}", nn.Linear(self.dim, self.dim))
            self.adapter.add_module(f"relu_{t_layer_i}", nn.ReLU())
        self.adapter.add_module("fc", nn.Linear(self.dim, num_classes))

        self.backbone = vit_model
        # Installing the adapter as the backbone's head makes backbone(x) return
        # adapter outputs. The adapter is then registered both as `adapter` and as
        # `backbone.head`. parameters() de-duplicates it, but state_dict() lists
        # both key prefixes.
        self.backbone.head = self.adapter

    def forward(self, x: Tensor) -> Tensor:
        """Run the frozen backbone, whose head is the adapter, and return its output."""
        return self.backbone(x)
    

In [None]:
from src.commons.utils_io import load_config
import hydra
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()
# Overrides passed to load_config for this experiment.
list_args = [
    "experiment=mp_naive",
    "sam_type=small",
    "data=levir-cd",
    "data.params.n_shape=3",
    "data.params.num_worker=0",
]
cfg = load_config(list_args)

module = hydra.utils.instantiate(cfg.model.instance)

# Wrap the model instantiated above. `model` is never defined in this notebook
# (the timm line that would create it is commented out), so passing `model`
# fails with NameError on a fresh kernel.
# NOTE(review): Adapter_ViT reads `.blocks[0].attn.qkv` and replaces `.head`,
# which the forward pass must use. Confirm that `module` (or the sub-module you
# actually want to adapt) exposes both.
NUM_CLASSES = 14  # TODO: document where this class count comes from
adapter_model = Adapter_ViT(module, num_classes=NUM_CLASSES)