In [4]:
import sys
sys.path.append("../src")

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
from torch.nn.utils.parametrizations import weight_norm
from timm.models.vision_transformer import (
    vit_small_patch8_224,
    vit_small_patch16_224,
    vit_base_patch16_224,
    vit_base_patch8_224
    )

from torchview import draw_graph
from graphviz.graphs import Digraph

from utils import Encoder, ResNet50

In [2]:
class ProjectionHead(nn.Sequential):
    """
    Named `nn.Sequential` container used for the projection MLP.

    It adds no behaviour of its own; the subclass only gives the head a
    distinct class name in module printouts and graphs.
    """

    def __init__(self, *layers):
        # Hand every positional argument straight to nn.Sequential.
        super().__init__(*layers)

class Encoder(nn.Module):
    """
    Backbone followed by an MLP projection head and a weight-normalised
    output layer.

    Parameters
    ----------
    backbone: str
        Backbone identifier. Must be a key of the dimension table below:
        "resnet50", "vit-s-8", "vit-s-16", "vit-b-8" or "vit-b-16".

    hidden_dim: int
        Width of the hidden layers of the projection MLP.

    bottleneck_dim: int
        Output width of the projection MLP. These features are
        L2-normalised in `forward` before the final projection.

    k_dim: int
        Output width of the final weight-normalised linear layer.

    num_layers: int
        Number of linear layers in the projection MLP. The first and the
        last layer are always created, so values below 2 behave like 2.
    """

    def __init__(
        self, 
        backbone: str,
        hidden_dim: int = 2048, 
        bottleneck_dim: int = 256,
        k_dim: int = 65536,
        num_layers: int = 3
        ):
        super().__init__()

        # Feature width produced by each supported backbone.
        backbone_dim_table = {
            "resnet50": 2048,
            "vit-s-8": 384,
            "vit-s-16": 384,
            "vit-b-8": 768,
            "vit-b-16": 768
        }

        self.encoder = get_model(backbone)
        embed_dim = backbone_dim_table[backbone]

        self.mlp = ProjectionHead(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            *[nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU()) for _ in range(num_layers - 2)],
            nn.Linear(hidden_dim, bottleneck_dim)
        )

        # Initialise only after the MLP exists. Calling apply() before the
        # head was built left its Linear layers on PyTorch's default init.
        # It must also run before the weight-normalised layer is created so
        # that layer keeps its own initialisation.
        self.apply(self._init_weights)

        self.k_projection = weight_norm(nn.Linear(bottleneck_dim, k_dim, bias=False))
        # original0 is the magnitude parameter of the weight-norm
        # parametrization; start every output unit at magnitude 1.
        self.k_projection.parametrizations.weight.original0.data.fill_(1)

    def _init_weights(self, module):
        """
        Initialise a single submodule; intended to be passed to `apply`.

        Linear layers get truncated-normal weights (std 0.02) and zero
        biases. Every other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """
        Run the backbone, the projection MLP, L2-normalise the result along
        the last dimension and apply the weight-normalised projection.

        Returns
        -------
        x: torch.Tensor
            Tensor whose last dimension has size `k_dim`.
        """
        x = self.encoder(x)
        x = self.mlp(x)
        x = F.normalize(x, p=2, dim=-1)
        x = self.k_projection(x)

        return x
    

class ResNet50(nn.Module):
    """
    ResNet-50 feature extractor.

    Wraps every child module of torchvision's `resnet50` except the last
    one in a single `nn.Sequential`, and flattens the result so each sample
    is returned as one feature vector.
    """

    def __init__(self):
        super().__init__()

        # Unpack the children and discard the final one (the classifier).
        *feature_layers, _ = resnet50().children()
        self.backbone = nn.Sequential(*feature_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)

        # Collapse everything after the batch dimension.
        return features.flatten(start_dim=1)
    

def get_model(backbone):
    """
    Build the backbone network identified by `backbone`.

    Parameters
    ----------
    backbone: str
        One of "resnet50", "vit-s-8", "vit-s-16", "vit-b-8", "vit-b-16".

    Returns
    -------
    model: nn.Module
        A `ResNet50` instance, or a timm Vision Transformer created with
        `dynamic_img_size=True` whose `fc_norm`, `head_drop` and `head`
        modules are replaced by `nn.Identity`.

    Raises
    ------
    AssertionError
        If `backbone` is not a supported identifier.
    """
    # Store constructors rather than instances so that only the requested
    # backbone is built. Each name maps to the constructor with the matching
    # patch size; the two ViT-B entries were previously swapped.
    model_table = {
        "resnet50": ResNet50,
        "vit-s-8": vit_small_patch8_224,
        "vit-s-16": vit_small_patch16_224,
        "vit-b-8": vit_base_patch8_224,
        "vit-b-16": vit_base_patch16_224
    }

    assert backbone in model_table, f"backbone must be one of {list(model_table.keys())}"

    if backbone == "resnet50":
        return model_table[backbone]()

    model = model_table[backbone](dynamic_img_size=True)

    # Neutralise the classification head so the ViT returns features.
    model.fc_norm = nn.Identity()
    model.head_drop = nn.Identity()
    model.head = nn.Identity()

    return model


In [2]:
def generate_graph(backbone: str, destination: str, input_size: tuple = (1, 3, 224, 224)) -> Digraph:
    """
    Creates a graph to visualize the architecture of an Encoder.

    Parameters
    ----------
    backbone: str
        The backbone identifier passed to `Encoder`.

    destination: str
        The directory where the generated graph will be saved. The file is
        named "<backbone>-architecture".

    input_size: tuple
        Shape of the dummy input used to trace the model. Defaults to a
        single 3-channel 224x224 input.

    Returns
    -------
    graph: Digraph
        A digraph object that visualizes the model.
    """

    encoder = Encoder(backbone)

    model_graph = draw_graph(
        encoder, input_size=input_size,
        graph_name="SimCLR",
        expand_nested=True,
        save_graph=True, directory=destination,
        filename=f"{backbone}-architecture"
    )

    graph = model_graph.visual_graph
    
    return graph

In [3]:
# Only the backbone names are needed here: generate_graph builds each model
# itself. Listing names instead of instantiating five networks keeps the cell
# cheap and removes its dependency on the model constructors being in scope.
backbones = ("resnet50", "vit-s-8", "vit-s-16", "vit-b-8", "vit-b-16")

# NOTE(review): the directory name is spelled "architechtures"; it is kept
# unchanged so any existing assets or links keep working.
destination = os.path.join("..", "assets", "architechtures")
os.makedirs(destination, exist_ok=True)

for backbone in backbones:
    generate_graph(backbone=backbone, destination=destination)

NameError: name 'ResNet50' is not defined

In [3]:
# Build the encoder on the "vit-s-16" backbone and display its module tree.
model = Encoder(backbone="vit-s-16")

model

Encoder(
  (encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
  

In [4]:
# Forward one random 3-channel 96x96 input and show the output shape.
img = torch.rand(1, 3, 96, 96)

model(img).shape

torch.Size([1, 65536])

In [7]:
# Rebuild the encoder on the "resnet50" backbone and display its module tree.
model = Encoder(backbone="resnet50")

model

Encoder(
  (encoder): ResNet50(
    (backbone): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        

In [8]:
# Forward one random 3-channel 96x96 input and show the output shape.
img = torch.rand(1, 3, 96, 96)

model(img).shape

torch.Size([1, 65536])

In [28]:
# `backbone` is a required argument of Encoder. "vit-s-16" is the backbone
# whose parameter count matches the total recorded for this cell.
model = Encoder(backbone="vit-s-16")

# Total number of parameters (backbone + projection MLP + output layer).
sum(param.numel() for param in model.parameters())

44017792

In [18]:
# Keep the bare timm ViT under its own name so that `model` (the Encoder the
# following cells call) is not overwritten when the notebook runs top to bottom.
vit_backbone = vit_small_patch16_224(dynamic_img_size=True)

# Shape of the token-level features, before the head is applied.
vit_backbone.forward_features(img).shape

torch.Size([1, 37, 384])

In [16]:
# model = vit_small_patch16_224(dynamic_img_size=True)

# sum(param.numel() for param in model.parameters())

In [19]:
# Shape of the input tensor currently bound to `img`.
img.shape

torch.Size([1, 3, 96, 96])

In [17]:
# Forward one random 3-channel 224x224 input and display the raw output tensor.
img = torch.rand(1, 3, 224, 224)

model(img)

tensor([[-0.0487, -0.0268,  0.0188,  ...,  0.0377, -0.0852, -0.0719]],
       grad_fn=<MmBackward0>)

In [18]:
# Output shape for the current `img`.
model(img).shape

torch.Size([1, 65536])

In [19]:
# Same forward pass with a random 3-channel 96x96 input; displays the raw output tensor.
img = torch.rand(1, 3, 96, 96)

model(img)

tensor([[-0.0500, -0.0255,  0.0241,  ...,  0.0315, -0.0823, -0.0726]],
       grad_fn=<MmBackward0>)

In [20]:
# Output shape for the current `img`.
model(img).shape

torch.Size([1, 65536])