In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

In [39]:
def count_trainable_params(model: nn.Module) -> int:
    """
    Tally the scalar weights of a PyTorch module that require gradients.

    Args:
        model (nn.Module): Module whose parameters are inspected.

    Returns:
        int: Total element count over all parameters with ``requires_grad`` set.
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters do not contribute to the trainable count.
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [40]:
# EfficientNet-B3 initialised with its ImageNet-1k (V1) pretrained weights.
backbone = models.efficientnet_b3(
    weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1
)

# Trainable parameters of the feature extractor only (`backbone.features`).
count_trainable_params(backbone.features)

10696232

In [41]:
class TripletNet(nn.Module):
    """Wrapper that runs one shared embedding network over three inputs."""

    def __init__(self, embedding_net):
        super().__init__()
        # The same module (and therefore the same weights) embeds all inputs.
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        """Return the embeddings of ``x1``, ``x2`` and ``x3`` as a 3-tuple, in that order."""
        embed = self.embedding_net
        return embed(x1), embed(x2), embed(x3)

    def get_embedding(self, x):
        """Embed a single input with the shared network."""
        return self.embedding_net(x)

In [67]:
class ConvNetModel(nn.Module):
    """Backbone feature extractor followed by a 2048-unit embedding head.

    Pipeline: ``backbone.features`` -> global average pool -> flatten ->
    lazily-sized linear layer -> ReLU -> dropout (p=0.6) -> L2 normalisation
    along dim 1.
    """

    def __init__(self, backbone):
        super().__init__()
        # Only the `.features` sub-module of the supplied backbone is kept.
        self.backbone = backbone.features
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # LazyLinear infers its input width on the first forward pass.
        self.fc1 = nn.LazyLinear(2048)
        self.dropout1 = nn.Dropout(0.6)

    def forward(self, x):
        """Map a batch of inputs to L2-normalised 2048-dimensional embeddings."""
        pooled = self.global_avg_pool(self.backbone(x))
        hidden = F.relu(self.fc1(pooled.flatten(1)))
        return F.normalize(self.dropout1(hidden), p=2, dim=1)

In [68]:
conv_net_model = ConvNetModel(backbone)

# Dry run on random data: checks the output shape and materialises the
# LazyLinear head. It is done in eval mode and without autograd so that the
# random batch does not update the BatchNorm running statistics of the
# pretrained backbone and no gradient graph is built.
dummy_sample = torch.randn(1, 3, 224, 224)
conv_net_model.eval()
with torch.no_grad():
    output = conv_net_model(dummy_sample)
conv_net_model.train()  # back to the default training mode for later cells
output.shape

torch.Size([1, 2048])

In [69]:
class DeepRankModel(nn.Module):
    """Three-branch embedding model producing an L2-normalised 2048-d vector.

    Branches:
        * two shallow paths, each a strided 8x8 convolution (3 -> 96 channels)
          followed by max pooling, flattening and L2 normalisation;
        * a ``ConvNetModel`` built on the supplied backbone.

    The branch outputs are concatenated along dim 1, projected to 2048
    features by a lazily-sized linear layer and L2-normalised again.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.convnet_model = ConvNetModel(backbone)

        # Shallow branch 1: conv with stride 16, then 3x3 max pool with stride 4.
        self.first_conv = nn.Conv2d(3, 96, kernel_size=8, stride=16, padding=4)
        self.first_maxpool = nn.MaxPool2d(kernel_size=3, stride=4, padding=1)

        # Shallow branch 2: conv with stride 32, then 7x7 max pool with stride 2.
        self.second_conv = nn.Conv2d(3, 96, kernel_size=8, stride=32, padding=4)
        self.second_maxpool = nn.MaxPool2d(kernel_size=7, stride=2, padding=3)

        # Input width is inferred from the concatenated features on first use.
        self.fc = nn.LazyLinear(2048)

    @staticmethod
    def _shallow_features(conv, pool, image):
        """Apply conv -> max pool, flatten per sample and L2-normalise along dim 1."""
        flat = torch.flatten(pool(conv(image)), 1)
        return F.normalize(flat, p=2, dim=1)

    def forward(self, first_input, second_input, backbone_input):
        """Embed one input per branch and fuse them into a single unit-norm vector."""
        shallow_a = self._shallow_features(
            self.first_conv, self.first_maxpool, first_input
        )
        shallow_b = self._shallow_features(
            self.second_conv, self.second_maxpool, second_input
        )
        deep = self.convnet_model(backbone_input)

        # Same column order as concatenating (first, second) and then the deep branch.
        merged = torch.cat((shallow_a, shallow_b, deep), dim=1)
        return F.normalize(self.fc(merged), p=2, dim=1)

In [70]:
deep_rank_model = DeepRankModel(backbone)

first_input = torch.randn(32, 3, 224, 224)
second_input = torch.randn(32, 3, 224, 224)
backbone_input = torch.randn(32, 3, 224, 224)

# Dry run on random data: checks the output shape and materialises the
# LazyLinear layers. It is done in eval mode and without autograd so that the
# random batch does not update the BatchNorm running statistics of the
# pretrained backbone and no gradient graph is kept for a batch of 32.
deep_rank_model.eval()
with torch.no_grad():
    output = deep_rank_model(first_input, second_input, backbone_input)
deep_rank_model.train()  # back to the default training mode for later cells
output.shape

torch.Size([32, 2048])

In [72]:
# Trainable parameters of the full model (backbone features, both shallow
# convolutions and the linear heads). Run this after the dry-run forward pass
# above: the LazyLinear layers only get concrete weights on their first call.
count_trainable_params(deep_rank_model)

24368872