# Transformer implementation

In [None]:
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional
from torchvision.ops import *
import torch
import torch.nn as nn

In [None]:
class ConvStemConfig(NamedTuple):
    """Settings for one conv + norm + activation layer of a convolutional stem.

    ``VisionTransformer`` forwards every field of each config to one
    ``Conv2dNormActivation`` layer when ``conv_stem_configs`` is given.
    """

    out_channels: int = 64  # output channels of the convolution
    kernel_size: int = 3  # kernel size of the convolution
    stride: int = 2  # stride of the convolution
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d  # factory for the normalization layer
    activation_layer: Callable[..., nn.Module] = nn.ReLU  # factory for the activation layer

In [None]:
def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Broadcast a token index across the feature axis of ``tokens``.

    Args:
        index:
            Tensor of shape (batch_size, idx_length) holding positions in
            [0, sequence_length).
        tokens:
            Tensor of shape (batch_size, sequence_length, dim); only its last
            dimension is read.

    Returns:
        A view of ``index`` with shape (batch_size, idx_length, dim) in which
        every index value is repeated ``dim`` times along the last axis.

    """
    feature_dim = tokens.size(-1)
    # Append a singleton axis, then stretch it (without copying) to feature_dim.
    column = index[..., None]
    return column.expand(-1, -1, feature_dim)

In [None]:
def get_at_index(tokens: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Gather the tokens located at the given sequence positions.

    Args:
        tokens:
            Tensor of shape (batch_size, sequence_length, dim).
        index:
            Tensor of shape (batch_size, index_length) whose entries are
            positions in [0, sequence_length).

    Returns:
        Tensor of shape (batch_size, index_length, dim) with the tokens picked
        out along the sequence axis, in the order given by ``index``.

    """
    # gather needs an index with the same rank as tokens, so repeat it over dim.
    gather_index = expand_index_like(index, tokens)
    selected = tokens.gather(dim=1, index=gather_index)
    return selected

In [None]:
class MLPBlock(MLP):
    """Two-layer feed-forward block (Linear -> GELU -> Dropout -> Linear -> Dropout)
    used inside each transformer encoder layer."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        # Re-initialise every linear layer: Xavier-uniform weights, near-zero biases.
        linear_layers = (layer for layer in self.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.normal_(layer.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load weights, renaming keys of pre-version-2 checkpoints first.

        Legacy checkpoints store the two linear layers as ``linear_1`` and
        ``linear_2``; the current layout keeps them at sequential indices
        ``0`` and ``3``.
        """
        version = local_metadata.get("version", None)
        is_legacy = version is None or version < 2

        if is_legacy:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            renames = [
                (f"{prefix}linear_{i+1}.{suffix}", f"{prefix}{3*i}.{suffix}")
                for i in range(2)
                for suffix in ("weight", "bias")
            ]
            for old_key, new_key in renames:
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

In [None]:
class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder layer: self-attention followed by an MLP,
    each wrapped in a residual connection."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention sub-layer (batch is the leading dimension).
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(p=dropout)

        # Feed-forward sub-layer.
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        """Apply the layer to a (batch_size, seq_length, hidden_dim) tensor and
        return a tensor of the same shape."""
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")

        normed = self.ln_1(input)
        attended = self.self_attention(normed, normed, normed, need_weights=False)[0]
        residual = self.dropout(attended) + input

        return residual + self.mlp(self.ln_2(residual))

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: learned positional embedding, dropout, a stack of
    ``EncoderBlock`` layers and a final normalization layer.

    Args:
        seq_length: Length of the stored positional embedding, i.e. number of
            class tokens plus number of patch tokens.
        num_layers: Number of ``EncoderBlock`` layers.
        num_heads: Attention heads per layer.
        hidden_dim: Token dimension.
        mlp_dim: Hidden dimension of the MLP in every layer.
        dropout: Dropout probability applied to the embedded tokens and inside
            the layers.
        attention_dropout: Dropout probability inside the attention modules.
        num_cls_tokens: Number of class tokens at the start of the sequence.
        norm_layer: Factory for the normalization layers.
    """

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        num_cls_tokens: int,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6)
    ):
        super().__init__()
        self.num_cls_tokens = num_cls_tokens
        self.seq_length = seq_length
        # Batch is the leading dimension because every EncoderBlock creates its
        # nn.MultiheadAttention with batch_first=True.

        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT

        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self,
                input: torch.Tensor,
                idx_keep: Optional[torch.Tensor] = None
               ):
        """Encode a batch of token sequences.

        Args:
            input:
                Tensor with shape (batch_size, seq_length, hidden_dim).
            idx_keep:
                Optional tensor with shape (batch_size, num_keep) of token
                positions. When given, only these tokens (selected after the
                positional embedding has been added) are encoded.

        Returns:
            Tensor with the encoded tokens, (batch_size, seq_length, hidden_dim)
            or (batch_size, num_keep, hidden_dim) when ``idx_keep`` is given.
        """
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")

        input = input + self._interpolate_pos_encoding(input, self.num_cls_tokens)
        if idx_keep is not None:
            # Use the helper defined in this file; the lightly `utils` module is
            # only imported later and is not needed for this.
            input = get_at_index(input, idx_keep)
        return self.ln(self.layers(self.dropout(input)))

    def _interpolate_pos_encoding(self,
                                  input: torch.Tensor,
                                  num_cls_tokens: int = 3):
        """Returns the positional embedding matching the given input.

        If the input has as many tokens as the stored embedding, the stored
        embedding is returned unchanged. Otherwise the input is treated as one
        class token followed by patch tokens: the patch part of the stored
        embedding is bicubically resized to the input's patch count and a
        single class embedding is put in front of it.

        Args:
            input:
               Input tensor with shape (batch_size, seq_length, hidden_dim).
            num_cls_tokens:
               Number of class-token positions at the start of
               ``self.pos_embedding``.

        Returns:
            Tensor with shape (1, seq_length, hidden_dim).
        """
        N = self.pos_embedding.shape[1] - num_cls_tokens
        if input.shape[1] - num_cls_tokens == N:
            return self.pos_embedding

        # One class token is assumed to remain in the input, so every other
        # token is a patch token.
        npatch = input.shape[1] - 1
        # NOTE(review): index 1 is kept from the original code, but the anchor
        # branch keeps class token 0, so index 0 looks like the intended
        # embedding — confirm before changing, since it affects trained weights.
        class_emb = self.pos_embedding[:, 1]
        pos_embedding = self.pos_embedding[:, num_cls_tokens:]

        dim = input.shape[-1]
        grid = int(math.sqrt(N))
        pos_embedding = nn.functional.interpolate(
            pos_embedding.reshape(1, grid, grid, dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode="bicubic",
        )
        pos_embedding = pos_embedding.permute(0, 2, 3, 1).reshape(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1)

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929, extended
    with several class tokens and two forward branches.

    ``forward`` returns encoded class tokens; the classification ``heads`` are
    built and initialised but are not applied by ``forward``.

    Args:
        image_size: Height and width of the (square) input images.
        patch_size: Side length of a patch; must divide ``image_size``.
        num_layers: Number of encoder layers.
        num_heads: Attention heads per encoder layer.
        hidden_dim: Token dimension.
        mlp_dim: Hidden dimension of the encoder MLPs.
        dropout: Dropout probability in the encoder.
        attention_dropout: Dropout probability in the attention modules.
        num_classes: Output size of the classification head.
        num_cls_tokens: Number of class tokens prepended to the patch tokens.
        representation_size: If set, a Linear + Tanh "pre_logits" layer of this
            size is inserted before the classification head.
        norm_layer: Factory for the normalization layers.
        conv_stem_configs: If set, a convolutional stem
            (https://arxiv.org/abs/2106.14881) replaces the single patchify
            convolution.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 6,
        num_heads: int = 4,
        hidden_dim: int = 512,
        mlp_dim: int = 512,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        num_cls_tokens: int = 3,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()

        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer
        self.num_cls_tokens = num_cls_tokens

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            # A final 1x1 convolution maps the stem output to hidden_dim channels.
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            # Non-overlapping patches: kernel and stride both equal patch_size.
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        self.seq_length = (image_size // patch_size) ** 2
        # Add the class tokens
        self.class_token = nn.Parameter(torch.zeros(1, self.num_cls_tokens, hidden_dim))
        self.seq_length += self.num_cls_tokens

        self.encoder = Encoder(
            self.seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            num_cls_tokens,
            norm_layer)

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Turn images (n, c, h, w) into patch tokens (n, n_h * n_w, hidden_dim)."""
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor,
                branch: str = 'target',
                idx_keep: Optional[torch.Tensor] = None):
        """Encode images and return class-token encodings.

        Args:
            x:
                Images with shape (batch_size, 3, image_size, image_size).
            branch:
                ``'target'`` encodes all class tokens together with every
                patch token and returns all class tokens. ``'anchor'`` keeps
                only the first class token in front of the patch tokens,
                optionally restricts the sequence to ``idx_keep`` and returns
                that single class token.
            idx_keep:
                Optional (batch_size, num_keep) positions of the tokens to
                encode; only used by the ``'anchor'`` branch.

        Returns:
            (batch_size, num_cls_tokens, hidden_dim) for ``'target'`` and
            (batch_size, hidden_dim) for ``'anchor'``.

        Raises:
            ValueError: If ``branch`` is neither ``'target'`` nor ``'anchor'``.
        """
        if branch not in ('target', 'anchor'):
            # Fail loudly instead of silently returning None.
            raise ValueError(f"Unknown branch {branch!r}; expected 'target' or 'anchor'.")

        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class tokens to the full batch and prepend them
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        if branch == 'target':
            x = self.encoder(x)
            return x[:, 0:self.num_cls_tokens]

        # Anchor branch: drop every class token except the first one.
        x = torch.cat((x[:, :1], x[:, self.num_cls_tokens:]), dim=1)
        x = self.encoder(x, idx_keep=idx_keep)
        return x[:, 0]

In [None]:
# Four identical conv stem layers: 64 output channels, 3x3 kernel, stride 2.
stemconfig = [ConvStemConfig(out_channels = 64, kernel_size = 3 , stride = 2) for i in range(4)]

# NOTE(review): `utils` is only imported further down in this file
# (`from lightly.models import utils`), so this line raises NameError when the
# cells run top to bottom — confirm the intended cell order.
keep,mask = utils.random_token_mask((2,197),mask_ratio=0.15)
# Backbone that uses the convolutional stem instead of the patchify convolution.
backbone = VisionTransformer(conv_stem_configs=stemconfig)


# MSN implementation

In [None]:
# ! pip install lightly

In [None]:
import copy
from lightly.models import utils
from typing import Tuple

In [None]:
def count_parameters(model) -> int:
    """Return the total number of scalar elements across all parameters of
    ``model`` that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
def deactivate_requires_grad(model: nn.Module):
    """Freeze a model by switching off ``requires_grad`` on all its parameters.

    No gradients are computed for the parameters afterwards, so the model is
    no longer trained.

    Examples:
        >>> backbone = resnet18()
        >>> deactivate_requires_grad(backbone)
    """
    for parameter in model.parameters():
        parameter.requires_grad_(False)

In [None]:
@torch.no_grad()
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float):
    """Updates parameters of `model_ema` with an exponential moving average of `model`.

    Every parameter of `model_ema` becomes ``m * ema + (1 - m) * online``,
    pairing the parameters of both models in iteration order. `model` itself
    is left unchanged.

    Momentum encoders are a crucial component of models such as MoCo or BYOL.

    Args:
        model: Online model that provides the new values.
        model_ema: Momentum model that is updated in place.
        m: Momentum in [0, 1]; ``m = 1`` leaves `model_ema` untouched.

    Examples:
        >>> backbone = resnet18()
        >>> projection_head = MoCoProjectionHead()
        >>> backbone_momentum = copy.deepcopy(backbone)
        >>> projection_head_momentum = copy.deepcopy(projection_head)
        >>>
        >>> # update momentum
        >>> update_momentum(backbone, backbone_momentum, m=0.999)
        >>> update_momentum(projection_head, projection_head_momentum, m=0.999)
    """
    # Distinct loop names: the original loop rebound `model` and `model_ema`.
    for ema_param, param in zip(model_ema.parameters(), model.parameters()):
        ema_param.data = ema_param.data * m + param.data * (1.0 - m)

In [None]:
def random_token_mask(size: Tuple[int, int],
                      mask_ratio: float = 0.6,
                      mask_class_token: bool = False,
                      device: Optional[torch.device] = None,
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Creates random token masks.

    Args:
        size:
            Size of the token batch for which to generate masks.
            Should be (batch_size, sequence_length).
        mask_ratio:
            Fraction of tokens to mask.
        mask_class_token:
            If False the class token (position 0) is never masked and at least
            one token is kept. If True the class token might be masked.
        device:
            Device on which to create the index masks (anything accepted by
            ``torch.rand(..., device=...)``). Defaults to the current default
            device.

    Returns:
        A (index_keep, index_mask) tuple where each index is a tensor.
        index_keep contains the indices of the unmasked tokens and has shape
        (batch_size, num_keep). index_mask contains the indices of the masked
        tokens and has shape (batch_size, sequence_length - num_keep).
        num_keep is equal to int(sequence_length * (1 - mask_ratio)).

    """
    batch_size, sequence_length = size
    num_keep = int(sequence_length * (1 - mask_ratio))

    noise = torch.rand(batch_size, sequence_length, device=device)
    if not mask_class_token and sequence_length > 0:
        # make sure that class token is not masked: its noise is below every
        # random value, so it always sorts first and lands in idx_keep
        noise[:, 0] = -1
        num_keep = max(1, num_keep)

    # tokens with the smallest noise are kept, the rest are masked
    indices = torch.argsort(noise, dim=1)
    idx_keep = indices[:, :num_keep]
    idx_mask = indices[:, num_keep:]

    return idx_keep, idx_mask

In [None]:
class DenseBlock(nn.Module):
    """Linear layer followed by LayerNorm and GELU.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: Whether the linear layer has a bias term.
    """

    def __init__(self,
                 in_features: int = 768,
                 out_features: int = 2048,
                 bias: bool = False,
                 ) -> None:
        super().__init__()

        self.dense_block = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=out_features, bias=bias),
            nn.LayerNorm(normalized_shape=out_features),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (..., in_features) to (..., out_features)."""
        return self.dense_block(x)

In [None]:
# Smoke test: DenseBlock maps the last dimension from 768 to 1536 features.
y = torch.rand(2,2,768)
DenseBlock(768,768*2)(y).shape

torch.Size([2, 2, 1536])

In [None]:
class ProjectionHead(nn.Module):
    """Three-stage projection MLP: two ``DenseBlock`` stages followed by a
    plain linear output layer.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of both dense blocks.
        out_features: Size of the last dimension of the output.
        bias: Whether the linear layers have bias terms.
    """

    def __init__(self,
                 in_features: int = 512,
                 hidden_features: int = 2048,
                 out_features: int = 512,
                 bias : bool = False
                 ) -> None:
        super().__init__()

        input_block = DenseBlock(in_features=in_features,
                                 out_features=hidden_features,
                                 bias=bias)
        hidden_block = DenseBlock(in_features=hidden_features,
                                  out_features=hidden_features,
                                  bias=bias)
        output_layer = nn.Linear(in_features=hidden_features,
                                 out_features=out_features,
                                 bias=bias)
        self.projection_head = nn.Sequential(input_block, hidden_block, output_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from (..., in_features) to (..., out_features)."""
        return self.projection_head(x)

In [None]:
# Dummy batch: 2 three-channel inputs of size 224x224.
x = torch.randn(2,3,224,224)

In [None]:
# Backbone built with the default VisionTransformer arguments (no conv stem).
backbone = VisionTransformer()

In [None]:
class MSN(nn.Module):
    """Siamese model with a masked anchor branch and a momentum target branch.

    The anchor backbone and projection head are trained; the target backbone
    and projection head are gradient-free copies that follow the anchor
    networks through an exponential moving average.

    Args:
        backbone: Backbone used for the anchor branch. It is called as
            ``backbone(x=..., branch=..., idx_keep=...)`` and must expose
            ``patch_size``. The target backbone is a deep copy of it.
        masking_ratio: Fraction of anchor tokens that is masked out.
        ema_p: Momentum of the moving-average update of the target networks.
    """

    def __init__(self,
                 backbone: nn.Module,
                 masking_ratio: float = 0.15,
                 ema_p: float = 0.996
                 ) -> None:
        super().__init__()

        self.masking_ratio = masking_ratio
        self.ema_p = ema_p

        self.anchor_backbone = backbone
        # Match the head input to the backbone width; 512 is the previous
        # hard-coded value and stays the fallback.
        self.anchor_projection_head = ProjectionHead(
            in_features=getattr(backbone, "hidden_dim", 512)
        )

        self.target_backbone = copy.deepcopy(self.anchor_backbone)
        self.target_projection_head = copy.deepcopy(self.anchor_projection_head)

        deactivate_requires_grad(self.target_backbone)
        deactivate_requires_grad(self.target_projection_head)

    def forward(self,
                views: list[torch.Tensor]
                ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute target and anchor projections.

        Args:
            views: Four image batches: ``views[0]`` goes through the target
                branch, ``views[1]``, ``views[2]`` and ``views[3]`` through the
                anchor branch.

        Returns:
            ``(target_projections, anchor_projections)`` where the three anchor
            projections are stacked along dimension 1.
        """
        # The target networks are moved towards the anchor networks on every
        # forward call, before the projections are computed.
        update_momentum(model=self.anchor_backbone,
                        model_ema=self.target_backbone,
                        m=self.ema_p,
                        )
        update_momentum(model=self.anchor_projection_head,
                        model_ema=self.target_projection_head,
                        m=self.ema_p,
                        )

        target_projections = self._target_forward(views[0])
        anchor_projections_sim = self._anchor_forward(views[1])
        anchor_projections_age = self._anchor_forward(views[2])
        anchor_projections_gender = self._anchor_forward(views[3])

        anchor_projections = torch.stack((anchor_projections_sim,
                                          anchor_projections_age,
                                          anchor_projections_gender
                                          ),
                                         dim=1
                                         )

        return (target_projections,
                anchor_projections
                )

    def _target_forward(self,
                        view: torch.Tensor
                        ) -> torch.Tensor:
        """Encode and project one view with the (unmasked) target networks."""
        target_encodings = self.target_backbone(x=view,
                                                branch='target'
                                                )
        return self.target_projection_head(x=target_encodings)

    def _anchor_forward(self,
                        view: torch.Tensor
                        ) -> torch.Tensor:
        """Encode and project one view with the anchor networks, keeping only a
        random subset of its tokens."""
        _, _, _, width = view.shape
        num_patches = (width // self.anchor_backbone.patch_size) ** 2
        # The anchor branch feeds one class token followed by the patch tokens
        # to the encoder, so the mask must cover num_patches + 1 positions.
        # Position 0 (the class token) is never masked by random_token_mask.
        idx_keep, _ = random_token_mask(size=(view.shape[0], num_patches + 1),
                                        mask_ratio=self.masking_ratio
                                        )
        # The indices are created on the default device; move them to the input.
        idx_keep = idx_keep.to(view.device)

        # Previously idx_keep=None was passed here, so no token was ever masked.
        anchor_encodings = self.anchor_backbone(x=view,
                                                branch='anchor',
                                                idx_keep=idx_keep)
        return self.anchor_projection_head(x=anchor_encodings)






In [None]:
# Smoke test: the first view feeds the target branch, the other three the anchor branch.
MSN(backbone)([x,x-.01,x-.02,x-.03])

(tensor([[[-0.3703, -0.2485, -1.0448,  ..., -0.1906, -1.1420,  0.0873],
          [-0.3931,  0.2382, -0.6288,  ...,  0.0809, -0.4583,  0.0217],
          [-0.3198,  0.5406, -0.7428,  ...,  0.1183, -0.4445,  0.0309]],
 
         [[ 0.1449,  0.0282, -0.4336,  ...,  0.2334, -0.6158, -0.0925],
          [-0.2670,  0.1333, -0.7477,  ...,  0.3959, -0.2559,  0.4295],
          [-0.1163,  0.4009, -0.6852,  ...,  0.2619, -0.0704,  0.0533]]]),
 tensor([[[-0.3588,  0.2480, -0.6336,  ...,  0.0374, -0.4314, -0.0163],
          [-0.3013,  0.2253, -0.6751,  ...,  0.0693, -0.4367, -0.0513],
          [-0.2425,  0.2062, -0.7045,  ...,  0.1065, -0.4489, -0.0964]],
 
         [[-0.1811,  0.0526, -0.7377,  ...,  0.3825, -0.4510,  0.3992],
          [-0.1445,  0.0192, -0.7054,  ...,  0.4038, -0.6250,  0.3425],
          [-0.1111,  0.0163, -0.6744,  ...,  0.4076, -0.7796,  0.2513]]],
        grad_fn=<StackBackward0>))

In [None]:
from lightly.models.modules.masked_autoencoder import MAEBackbone
from torchvision.models import vit_b_16

In [None]:
from __future__ import annotations

import math
from functools import partial
from typing import Callable, List, Optional

import torch
import torch.nn as nn

# vision_transformer requires torchvision >= 0.12
from torchvision.models import vision_transformer
from torchvision.models.vision_transformer import ConvStemConfig

from lightly.models import utils

In [None]:
class MSNEncoder(vision_transformer.Encoder):
    """torchvision ViT encoder that can encode a subset of the input tokens.

    Encodes patch embeddings. The code follows the Masked Autoencoder
    encoder [0], inspired by [1].

    - [0]: Masked Autoencoder, 2021, https://arxiv.org/abs/2111.06377
    - [1]: https://github.com/facebookresearch/mae

    Attributes:
        seq_length:
            Token sequence length, including the class token.
        num_layers:
            Number of transformer blocks.
        num_heads:
            Number of attention heads.
        hidden_dim:
            Dimension of the input and output tokens.
        mlp_dim:
            Dimension of the MLP in the transformer block.
        dropout:
            Percentage of elements set to zero after the MLP in the transformer.
        attention_dropout:
            Percentage of elements set to zero after the attention head.

    """

    def __init__(
        self,
        seq_length: int = 196,
        num_layers: int = 6,
        num_heads: int = 4,
        hidden_dim: int = 512,
        mlp_dim: int = 512,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        encoder_kwargs = dict(
            seq_length=seq_length,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout=dropout,
            attention_dropout=attention_dropout,
            norm_layer=norm_layer,
        )
        super().__init__(**encoder_kwargs)

    @classmethod
    def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder) -> MSNEncoder:
        """Build a MSNEncoder that shares the modules of a torchvision ViT encoder."""
        # The constructor arguments are placeholders: every learned component
        # is replaced by the corresponding one of vit_encoder right below.
        placeholder = dict(
            seq_length=1,
            num_layers=1,
            num_heads=1,
            hidden_dim=1,
            mlp_dim=1,
            dropout=0,
            attention_dropout=0,
        )
        encoder = cls(**placeholder)
        for attr in ("pos_embedding", "dropout", "layers", "ln"):
            setattr(encoder, attr, getattr(vit_encoder, attr))
        return encoder

    def forward(
        self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Encode input tokens.

        Args:
            input:
                Batch of token sequences.
            idx_keep:
                Tensor with shape (batch_size, num_tokens_to_keep) where each
                entry is an index of the token to keep in the respective batch.
                If specified, only the indexed tokens will be encoded.

        Returns:
            Batch of encoded output tokens.
        """
        tokens = input + self.interpolate_pos_encoding(input)
        if idx_keep is not None:
            # Selection happens after the positional embedding was added.
            tokens = utils.get_at_index(tokens, idx_keep)
        return self.ln(self.layers(self.dropout(tokens)))

    def interpolate_pos_encoding(self,
                                 input: torch.Tensor,
                                 num_cls_tokens: int = 3):
        """Return the positional embedding resized to the input's token count.

        All positions except the first (class token) one are bicubically
        interpolated when the input has a different number of patch tokens
        than ``self.pos_embedding``. This allows encoding variable sized images.

        Args:
            input:
               Batch of token sequences; its second dimension is the sequence
               length and its last dimension the token dimension.
            num_cls_tokens:
               Not used by this method; exactly one class token is assumed.

        """
        target_patches = input.shape[1] - 1
        stored_patches = self.pos_embedding.shape[1] - 1
        if target_patches == stored_patches:
            return self.pos_embedding

        cls_embedding = self.pos_embedding[:, 0]
        patch_embedding = self.pos_embedding[:, 1:]
        dim = input.shape[-1]

        # Lay the patch positions out on a square grid with channels first,
        # which is the layout interpolate expects.
        side = int(math.sqrt(stored_patches))
        grid = patch_embedding.reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(
            grid,
            scale_factor=math.sqrt(target_patches / stored_patches),
            mode="bicubic",
        )
        patch_embedding = grid.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((cls_embedding.unsqueeze(0), patch_embedding), dim=1)

In [None]:
class MAEBackbone(vision_transformer.VisionTransformer):
    """Backbone for the Masked Autoencoder model [0].

    Converts images into patches and encodes them. Code inspired by [1].
    Note that this implementation uses a learned positional embedding while [0]
    uses a fixed positional embedding. The encoder is a ``MSNEncoder``.

    - [0]: Masked Autoencoder, 2021, https://arxiv.org/abs/2111.06377
    - [1]: https://github.com/facebookresearch/mae
    - [2]: Early Convolutions Help Transformers See Better, 2021, https://arxiv.org/abs/2106.14881.

    Attributes:
        image_size:
            Input image size.
        patch_size:
            Width and height of the image patches. image_size must be a multiple
            of patch_size.
        num_layers:
            Number of transformer blocks.
        num_heads:
            Number of attention heads.
        hidden_dim:
            Dimension of the input and output tokens.
        mlp_dim:
            Dimension of the MLP in the transformer block.
        dropout:
            Percentage of elements set to zero after the MLP in the transformer.
        attention_dropout:
            Percentage of elements set to zero after the attention head.
        num_classes:
            Number of classes for the classification head. Currently not used.
        representation_size:
            If specified, an additional linear layer is added before the
            classification head to change the token dimension from hidden_dim
            to representation_size. Currently not used.
        norm_layer:
            Callable that creates a normalization layer.
        conv_stem_configs:
            If specified, a convolutional stem is added at the beginning of the
            network following [2]. Not used in the original Masked Autoencoder
            paper [0].

    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0,
        attention_dropout: float = 0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__(
            image_size=image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout=dropout,
            attention_dropout=attention_dropout,
            num_classes=num_classes,
            representation_size=representation_size,
            norm_layer=norm_layer,
            conv_stem_configs=conv_stem_configs,
        )
        # Replace the encoder built by the parent class with one that supports
        # idx_keep and positional-embedding interpolation.
        self.encoder = MSNEncoder(
            seq_length=self.seq_length,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout=dropout,
            attention_dropout=attention_dropout,
            norm_layer=norm_layer,
        )

    @classmethod
    def from_vit(cls, vit: vision_transformer.VisionTransformer) -> MAEBackbone:
        """Creates a MAEBackbone that shares the modules of a ViT model."""
        # Create a new instance with dummy values as they will be overwritten
        # by the copied vit attributes
        backbone = cls(
            image_size=vit.image_size,
            patch_size=vit.patch_size,
            num_layers=1,
            num_heads=1,
            hidden_dim=vit.hidden_dim,
            mlp_dim=vit.mlp_dim,
            dropout=vit.dropout,
            attention_dropout=vit.attention_dropout,
            num_classes=vit.num_classes,
            representation_size=vit.representation_size,
            norm_layer=vit.norm_layer,
        )
        backbone.conv_proj = vit.conv_proj
        backbone.class_token = vit.class_token
        backbone.seq_length = vit.seq_length
        backbone.heads = vit.heads
        # The encoder class defined in this file is MSNEncoder; the previously
        # referenced name MAEEncoder is not defined here.
        backbone.encoder = MSNEncoder.from_vit_encoder(vit.encoder)
        return backbone

    def forward(
        self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Returns encoded class tokens from a batch of images.

        Args:
            images:
                Tensor with shape (batch_size, channels, image_size, image_size).
            idx_keep:
                Tensor with shape (batch_size, num_tokens_to_keep) where each
                entry is an index of the token to keep in the respective batch.
                If specified, only the indexed tokens will be passed to the
                encoder.

        Returns:
            Tensor with shape (batch_size, hidden_dim) containing the
            encoded class token (the first output token) for every image.

        """
        out = self.encode(images, idx_keep)
        class_token = out[:, 0]
        return class_token

    def encode(
        self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Returns encoded class and patch tokens from images.

        Args:
            images:
                Tensor with shape (batch_size, channels, image_size, image_size).
            idx_keep:
                Tensor with shape (batch_size, num_tokens_to_keep) where each
                entry is an index of the token to keep in the respective batch.
                If specified, only the indexed tokens will be passed to the
                encoder.

        Returns:
            Tensor with shape (batch_size, sequence_length, hidden_dim)
            containing the encoded class and patch tokens for every image.

        """
        out = self.images_to_tokens(images, prepend_class_token=True)
        return self.encoder(out, idx_keep)

    def images_to_tokens(
        self, images: torch.Tensor, prepend_class_token: bool
    ) -> torch.Tensor:
        """Converts images into patch tokens.

        Args:
            images:
                Tensor with shape (batch_size, channels, image_size, image_size).
            prepend_class_token:
                If True, ``self.class_token`` is prepended to the patch tokens
                of every image.

        Returns:
            Tensor with the patch tokens along dimension 1 and the token
            dimension last, preceded by the class token(s) when
            ``prepend_class_token`` is True.
        """
        x = self.conv_proj(images)
        # (batch, hidden_dim, h, w) -> (batch, h * w, hidden_dim)
        tokens = x.flatten(2).transpose(1, 2)
        if prepend_class_token:
            tokens = utils.prepend_class_token(tokens, self.class_token)
        return tokens

In [None]:
# Smoke test: wrap the backbone and check the shape of the returned class token.
MAEBackbone.from_vit(backbone)(x).shape

torch.Size([2, 512])

In [None]:
import numpy as np
import copy
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

# Seed NumPy's and PyTorch's global random number generators with 0 so that
# random draws in the following cells are repeatable between runs.
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fb045fec710>

In [None]:
class PatchEmbeddings(nn.Module):
    """Split an image batch into non-overlapping patches and embed each one.

    An input of shape (b, c, h, w) is mapped to a patch sequence of shape
    (b, p, embedding_dim), where b is the batch size, c the number of
    channels, h / w the image height / width and p the number of patches.

    Args:
        images_size: Side length of the input image (width or height).
        patch_size: Side length of one image patch (width or height).
        num_channels: Number of channels in the input images (usually 3).
        embedding_dim: Dimensionality of the patch embedding.
    """

    def __init__(self,
                 images_size: int =  224,
                 patch_size: int = 16,
                 num_channels: int = 3,
                 embedding_dim: int = 512
                ) -> None:
        super().__init__()
        self.image_size, self.patch_size = images_size, patch_size
        self.num_channels, self.hidden_size = num_channels, embedding_dim

        # Patches tile a square grid; a partial patch at the border is dropped.
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side ** 2

        # A convolution whose kernel and stride both equal the patch size
        # applies one linear projection to every non-overlapping patch.
        self.projection_layer = nn.Conv2d(
            self.num_channels,
            self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        """Return the patch embeddings of ``x`` with shape (b, p, embedding_dim)."""
        projected = self.projection_layer(x)
        # (b, dim, h', w') -> (b, dim, h' * w') -> (b, h' * w', dim)
        return torch.flatten(projected, start_dim=2).permute(0, 2, 1)

In [None]:
class InputEmbeddings(nn.Module):
    """Convert images into patch tokens, prepend CLS tokens and add positions.

    The output sequence has ``num_patches + num_cls_tokens`` tokens, so the
    positional table is built with exactly that many rows.

    Args:
        images_size: Side length of the input image (width or height).
        num_channels: Number of channels in the input images (usually 3).
        patch_size: Side length of one image patch (width or height).
        embedding_dim: Dimensionality of the patch embedding.
        num_cls_tokens: Number of classification tokens prepended to the
            patch sequence.
        dropout_p: Dropout probability applied to the final embeddings.
        learnable_pos_encoding: If True the positional embeddings are a
            learnable parameter; otherwise a fixed sinusoidal table is used.
    """

    def __init__(self,
                 images_size: int =  224,
                 num_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 512,
                 num_cls_tokens: int = 3,
                 dropout_p: float = 0.0,
                 learnable_pos_encoding: bool = True
                ) -> None:
        super().__init__()
        self.image_size = images_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embedding_dim = embedding_dim
        self.num_cls_tokens = num_cls_tokens
        self.dropout_p = dropout_p
        self.learnable_pos_encoding = learnable_pos_encoding

        self.patch_embeddings = PatchEmbeddings(images_size=self.image_size,
                                                patch_size=self.patch_size,
                                                num_channels=self.num_channels,
                                                embedding_dim=self.embedding_dim
                                                )

        self.cls_token = nn.Parameter(
            torch.randn(size=(1, self.num_cls_tokens, self.embedding_dim))
        )

        num_patches = self.patch_embeddings.num_patches
        if self.learnable_pos_encoding:
            # One row per token that reaches `forward`'s addition: the CLS
            # tokens are concatenated before the positions are added, so the
            # table must cover them as well or the addition cannot broadcast.
            sequence_length = num_patches + self.num_cls_tokens
            self.positional_embeddings = nn.Parameter(
                torch.empty(1, sequence_length, self.embedding_dim).normal_(std=0.02)
            )
        else:
            # Registered as a non-persistent buffer: it follows the module
            # across devices/dtypes but is not trained and adds no
            # state_dict entry.
            self.register_buffer(
                "positional_embeddings",
                self._get_positional_embeddings(num_patches=num_patches,
                                                embedding_dim=self.embedding_dim),
                persistent=False,
            )

        self.dropout = nn.Dropout(self.dropout_p)

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        """Embed an image batch.

        Args:
            x: Image batch accepted by ``PatchEmbeddings``, i.e. (b, c, h, w).

        Returns:
            Tensor of shape (b, num_cls_tokens + num_patches, embedding_dim).
        """
        x = self.patch_embeddings(x)
        batch_size, _, _ = x.size()
        # Share the same CLS parameters across the batch without copying.
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.positional_embeddings
        x = self.dropout(x)
        return x

    def _get_positional_embeddings(self,
                                   num_patches: int,
                                   embedding_dim: int
                                   ) -> torch.Tensor:
        """Build the fixed sinusoidal positional table.

        Entry (i, j) is ``sin(i / 10000 ** (j / embedding_dim))`` for even j
        and ``cos(i / 10000 ** ((j - 1) / embedding_dim))`` for odd j.

        Args:
            num_patches: Number of patch tokens (CLS tokens are added on top).
            embedding_dim: Width of every positional vector.

        Returns:
            Tensor of shape (1, num_patches + num_cls_tokens, embedding_dim).
        """
        sequence_length = num_patches + self.num_cls_tokens
        positions = torch.arange(sequence_length, dtype=torch.float32).unsqueeze(1)
        dims = torch.arange(embedding_dim)
        is_even = dims % 2 == 0
        # An odd column j reuses the exponent of the even column j - 1.
        exponents = (dims - dims % 2).to(torch.float32) / embedding_dim
        angles = positions / torch.pow(10000.0, exponents)
        table = torch.where(is_even, torch.sin(angles), torch.cos(angles))
        return table.unsqueeze(0)


In [None]:
class FeedForwardNet(nn.Module):
    """Position-wise feed-forward network used inside an attention block.

    The network is Linear -> GELU -> Linear -> Dropout and maps the last
    dimension from ``embedding_dim`` to ``hidden_dim`` and back.

    Args:
        embedding_dim: Dimensionality of the patch embedding.
        hidden_dim: Dimensionality of the hidden layer.
        dropout_p: Dropout probability applied after the second linear layer.
    """

    def __init__(self,
               embedding_dim: int = 512,
               hidden_dim = 192,
               dropout_p: float = 0.0
               ) -> None:
        super().__init__()
        expand = nn.Linear(embedding_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, embedding_dim)
        self.feed_forward = nn.Sequential(
            expand,
            nn.GELU(),
            contract,
            nn.Dropout(dropout_p),
        )

    def forward(self,
              x: torch.tensor
              )-> torch.Tensor:
        """Apply the feed-forward stack to ``x`` along its last dimension."""
        return self.feed_forward(x)

In [None]:
class AttentionBlock(nn.Module):
    """Transformer encoder block: self-attention followed by a feed-forward net.

    Captures the interaction between the tokens of a sequence. Inputs and
    outputs are batch-first, i.e. (batch_size, sequence_length, embedding_dim).

    Args:
        embedding_dim: Dimensionality of input and attention feature vectors.
        hidden_dim: Dimensionality of the hidden layer in the feed-forward
            network.
        num_heads: Number of heads in the multi-head attention.
        dropout_p: Dropout probability used in the attention and in the
            feed-forward network.
        pre_layer_norm: If True, layer norm is applied before each sub-layer
            (pre-LN); otherwise after each residual addition (post-LN).
    """

    def __init__(self,
                 embedding_dim: int = 512,
                 hidden_dim: int = 192,
                 num_heads: int = 4,
                 dropout_p: float = 0.0,
                 pre_layer_norm: bool = True
                 ) -> None:
        super().__init__()
        self.pre_layer_norm = pre_layer_norm
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=embedding_dim)

        # batch_first=True: the token sequences in this file are laid out as
        # (batch, seq, dim). With the default (seq, batch, dim) layout the
        # attention would mix samples of a batch instead of tokens.
        self.attention_head = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=dropout_p,
                                                    batch_first=True)

        self.layer_norm_2 = nn.LayerNorm(normalized_shape=embedding_dim)

        self.feed_forward = FeedForwardNet(embedding_dim=embedding_dim,
                                           hidden_dim=hidden_dim,
                                           dropout_p=dropout_p
                                           )

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        """Run the block on ``x``.

        Args:
            x: Tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
            Tensor of the same shape as ``x``: the output of the full block
            (attention and feed-forward sub-layers with their residuals).
        """
        if self.pre_layer_norm:
            normed = self.layer_norm_1(x)
            # Only the attention output is needed, not the attention weights.
            attention_out = self.attention_head(
                normed, normed, normed, need_weights=False
            )[0]
            residual = x + attention_out
            feed_forward_out = self.feed_forward(self.layer_norm_2(residual))
            return residual + feed_forward_out

        attention_out = self.attention_head(x, x, x, need_weights=False)[0]
        normed = self.layer_norm_1(x + attention_out)
        feed_forward_out = self.feed_forward(normed)
        # Return the block output, not just the intermediate attention output.
        return self.layer_norm_2(normed + feed_forward_out)

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self,
                 num_layers: int = 6,
                 num_heads: int = 4,
                 hidden_dim: int = 192,
                 pre_layer_norm: bool= True,
                 images_size: int =  224,
                 num_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 512,
                 num_cls_tokens: int = 3,
                 dropout_p: float = 0.0,
                 learnable_pos_encoding: bool = False
                ) -> None:
      """Build patch embedding, class tokens, positional table and encoder.

      Args:
          num_layers: Number of attention blocks stacked in the encoder.
          num_heads: Number of attention heads per block.
          hidden_dim: Hidden width of each block's feed-forward network.
          pre_layer_norm: Whether blocks apply layer norm before (True) or
              after (False) their sub-layers.
          images_size: Side length of the input image (width or height).
          num_channels: Number of channels in the input images.
          patch_size: Side length of one image patch.
          embedding_dim: Dimensionality of the token embeddings.
          num_cls_tokens: Number of class tokens prepended to the patches.
          dropout_p: Dropout probability forwarded to the attention blocks.
          learnable_pos_encoding: Stored on the instance only; the positional
              embeddings created here are always a learnable parameter.
      """
      super().__init__()

      self.image_size = images_size
      self.patch_size = patch_size
      self.num_channels = num_channels
      self.embedding_dim = embedding_dim
      self.num_cls_tokens = num_cls_tokens
      self.dropout_p = dropout_p
      self.learnable_pos_encoding = learnable_pos_encoding

      # Only the patch projection lives here; class tokens and positional
      # embeddings are owned by this module (see below).
      self.input_embeddings = PatchEmbeddings(images_size= self.image_size,
                                              num_channels= self.num_channels,
                                              patch_size= self.patch_size,
                                              embedding_dim= self.embedding_dim,
                                              )

      self.num_layers = num_layers
      self.num_heads = num_heads
      self.hidden_dim = hidden_dim
      self.pre_layer_norm = pre_layer_norm

      def make_block() -> nn.Module:
        # A fresh instance per call so that every layer owns its weights.
        return AttentionBlock(embedding_dim= self.embedding_dim,
                              hidden_dim= self.hidden_dim,
                              num_heads= self.num_heads,
                              dropout_p=self.dropout_p,
                              pre_layer_norm=self.pre_layer_norm)

      # `attention_block` is kept as an attribute (and state_dict prefix) for
      # backward compatibility; it doubles as the first encoder layer.
      self.attention_block = make_block()
      encoder_layers = [self.attention_block]
      encoder_layers += [make_block() for _ in range(num_layers - 1)]

      self.cls_token = nn.Parameter(torch.randn(size= (1, self.num_cls_tokens, self.embedding_dim)))
      num_patches = self.input_embeddings.num_patches
      # One positional vector per token: class tokens followed by patches.
      self.positional_embeddings = nn.Parameter(torch.empty(1, num_patches + self.num_cls_tokens, self.embedding_dim).normal_(std=0.02))

      # Repeating a single block instance would tie the weights of all
      # layers; stack distinct blocks instead. The slice keeps the encoder
      # empty when num_layers is 0.
      self.encoder = nn.Sequential(*encoder_layers[:num_layers])


    def forward(self,
                x: torch.tensor,
                branch: str = 'target'
                )-> torch.Tensor:
      """Encode an image batch along the 'target' or the 'anchor' branch.

      Both branches embed the patches, prepend all class tokens and add the
      positional embeddings. The 'anchor' branch additionally drops every
      class token except the first one before running the encoder.

      Args:
          x: Image batch fed to ``self.input_embeddings``
              (presumably (batch, channels, height, width) -- TODO confirm).
          branch: Either 'target' or 'anchor'.

      Returns:
          The encoder output for the whole token sequence.

      NOTE(review): no other ``branch`` value is handled in the code visible
      here; unless a further branch follows, such a call returns None.
      """

      if branch == 'target':
        # Patch tokens, laid out as (batch, num_patches, embedding_dim).
        x = self.input_embeddings(x)
        batch_size, _, _ = x.size()
        # Broadcast the shared class tokens over the batch.
        cls_token = self.cls_token.expand(batch_size,-1,-1)
        x = torch.cat([cls_token,x], dim=1)
        x = x + self.positional_embeddings
        x = self.encoder(x)
        return x#[:,0:self.num_cls_tokens]

      elif branch == 'anchor':
        x = self.input_embeddings(x)
        batch_size, _, _ = x.size()
        cls_token = self.cls_token.expand(batch_size,-1,-1)
        x = torch.cat([cls_token,x], dim=1)
        # Positions are added before the extra class tokens are removed, so
        # the kept tokens retain the positional vectors of their original slots.
        x = x + self.positional_embeddings
        # Keep the first class token plus everything after the class tokens.
        x = torch.cat((x[:,:1],x[:,self.num_cls_tokens:]),dim = 1)
        x = self.encoder(x)
        return x#[:,0]