# Transformer implementation

In [1]:
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
from types import FunctionType
import torch
import torch.nn as nn
import pytorch_lightning as pl

In [2]:
class ConvStemConfig(NamedTuple):
    out_channels: int = 64
    kernel_size: int = 3
    stride: int = 2
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    activation_layer: Callable[..., nn.Module] = nn.ReLU

In [3]:
def log_api_usage_once(obj: Any) -> None:

    """
    Logs API usage(module and name) within an organization.
    In a large ecosystem, it's often useful to track the PyTorch and
    TorchVision APIs usage. This API provides the similar functionality to the
    logging module in the Python stdlib. It can be used for debugging purpose
    to log which methods are used and by default it is inactive, unless the user
    manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Please note it is triggered only once for the same API call within a process.
    It does not collect any data from open-source users since it is no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
    module = obj.__module__
    if not module.startswith("torchvision"):
        module = f"torchvision.internal.{module}"
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{module}.{name}")

In [4]:
def make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
    Otherwise, we will make a tuple of length n, all with value of x.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))

In [5]:
def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Expands the index along the last dimension of the input tokens.

    Args:
        index:
            Index tensor with shape (batch_size, idx_length) where each entry is
            an index in [0, sequence_length).
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim).

    Returns:
        Index tensor with shape (batch_size, idx_length, dim) where the original
        indices are repeated dim times along the last dimension.

    """
    dim = tokens.shape[-1]
    index = index.unsqueeze(-1).expand(-1, -1, dim)
    return index

In [6]:
def get_at_index(tokens: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Selects tokens at index.

    Args:
        tokens:
            Token tensor with shape (batch_size, sequence_length, dim).
        index:
            Index tensor with shape (batch_size, index_length) where each entry is
            an index in [0, sequence_length).

    Returns:
        Token tensor with shape (batch_size, index_length, dim) containing the
        selected tokens.

    """
    index = expand_index_like(index, tokens)
    return torch.gather(tokens, 1, index)

In [52]:
class ConvNormActivation(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(dilation, int):
                padding = (kernel_size - 1) // 2 * dilation
            else:
                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
                kernel_size = make_ntuple(kernel_size, _conv_dim)
                dilation = make_ntuple(dilation, _conv_dim)
                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
        if bias is None:
            bias = norm_layer is None

        layers = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        log_api_usage_once(self)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )


In [53]:
class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution2d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv2d,
        )

In [54]:
class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
            Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
        log_api_usage_once(self)

In [55]:
class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

In [56]:
class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

In [57]:
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        num_cls_tokens: int,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6)
    ):
        super().__init__()
        self.num_cls_tokens = num_cls_tokens
        self.seq_length = seq_length
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiAttention() by default

        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT

        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self,
                input: torch.Tensor,
                idx_keep: Optional[torch.Tensor] = None
               ):

        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")

        input = input + self._interpolate_pos_encoding(input,self.num_cls_tokens)
        if idx_keep is not None:
            input = get_at_index(input, idx_keep)
        return self.ln(self.layers(self.dropout(input)))

    def _interpolate_pos_encoding(self,
                                  input: torch.Tensor,
                                  num_cls_tokens: int = 3):
        """Returns the interpolated positional embedding for the given input.

        This function interpolates self.pos_embedding for all tokens in the input,
        ignoring the class token. This allows encoding variable sized images.

        Args:
            input:
               Input tensor with shape (batch_size, num_sequences).
            num_cls_tokens:
               number of classification tokens prepended to the
        """

        npatch = input.shape[1] - num_cls_tokens
        N = self.pos_embedding.shape[1] - num_cls_tokens
        diff = num_cls_tokens -1
        if npatch == N:
            return self.pos_embedding

        else:
            npatch += diff
            class_emb = self.pos_embedding[:, 1]
            pos_embedding = self.pos_embedding[:, 1+diff:]
        dim = input.shape[-1]
        pos_embedding = nn.functional.interpolate(
            pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode="bicubic",
        )
        pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1)

In [58]:
class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 4,
        hidden_dim: int = 512,
        mlp_dim: int = 512,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        num_cls_tokens: int = 3,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer
        self.num_cls_tokens = num_cls_tokens

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        self.seq_length = (image_size // patch_size) ** 2
        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, self.num_cls_tokens, hidden_dim))
        self.seq_length += self.num_cls_tokens

        self.encoder = Encoder(
            self.seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            num_cls_tokens,
            norm_layer)
        #self.seq_length = seq_length #+ self.num_cls_tokens

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        #torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        #torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor,
                branch: str = 'target',
                idx_keep: Optional[torch.Tensor] = None):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        if branch == 'target':
            batch_class_token = self.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            x = self.encoder(x)
            return x[:,0:self.num_cls_tokens]
        elif branch == 'anchor':
            batch_class_token = self.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            x = torch.cat((x[:,:1],x[:,self.num_cls_tokens:]),dim = 1)
            x = self.encoder(x,idx_keep=idx_keep)
            return x[:,0]

        # Classifier "token" as used by standard language architectures
        # else:
        #x = x[:, 0]
        #x = self.heads(x)

        # return x
# stemconfig = [ConvStemConfig(out_channels = 64, kernel_size = 3 , stride = 2) for i in range(4)]

In [59]:
stemconfig = [ConvStemConfig(out_channels = 64, kernel_size = 3 , stride = 2) for i in range(4)]
x = torch.randn(2,3,224,224)
VisionTransformer()(x,'anchor')

tensor([[ 0.6665, -0.2615, -1.7997,  ...,  0.1416,  1.7796,  0.8774],
        [-0.8177, -0.4620, -1.3975,  ..., -0.3456,  1.2394,  0.5589]],
       grad_fn=<SelectBackward0>)

# MSN implementation

In [13]:
# ! pip install lightly

In [14]:
import copy
# from lightly.models import utils
from typing import Tuple

In [15]:
def count_parameters(model:nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [16]:
def deactivate_requires_grad(model: nn.Module):
    """Deactivates the requires_grad flag for all parameters of a model.

    This has the same effect as permanently executing the model within a `torch.no_grad()`
    context. Use this method to disable gradient computation and therefore
    training for a model.

    Examples:
        >>> backbone = resnet18()
        >>> deactivate_requires_grad(backbone)
    """
    for param in model.parameters():
        param.requires_grad = False

In [17]:
@torch.no_grad()
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float):
    """Updates parameters of `model_ema` with Exponential Moving Average of `model`

    Momentum encoders are a crucial component fo models such as MoCo or BYOL.

    Examples:
        >>> backbone = resnet18()
        >>> projection_head = MoCoProjectionHead()
        >>> backbone_momentum = copy.deepcopy(moco)
        >>> projection_head_momentum = copy.deepcopy(projection_head)
        >>>
        >>> # update momentum
        >>> update_momentum(moco, moco_momentum, m=0.999)
        >>> update_momentum(projection_head, projection_head_momentum, m=0.999)
    """
    for model_ema, model in zip(model_ema.parameters(), model.parameters()):
        model_ema.data = model_ema.data * m + model.data * (1.0 - m)

In [18]:
def random_token_mask(size: Tuple[int, int],
                      mask_ratio: float = 0.6,
                      mask_class_token: bool = False,
                     ) -> torch.Tensor:
    """Creates random token masks.

    Args:
        size:
            Size of the token batch for which to generate masks.
            Should be (batch_size, sequence_length).
        mask_ratio:
            Percentage of tokens to mask.
        mask_class_token:
            If False the class token is never masked. If True the class token
            might be masked.
        device:
            Device on which to create the index masks.

    Returns:
        A (index_keep, index_mask) tuple where each index is a tensor.
        index_keep contains the indices of the unmasked tokens and has shape
        (batch_size, num_keep). index_mask contains the indices of the masked
        tokens and has shape (batch_size, sequence_length - num_keep).
        num_keep is equal to sequence_length * (1- mask_ratio).

    """
    batch_size, sequence_length = size
    num_keep = int(sequence_length * (1 - mask_ratio))

    noise = torch.rand(batch_size, sequence_length)
    if not mask_class_token and sequence_length > 0:
        # make sure that class token is not masked
        noise[:, 0] = -1
        num_keep = max(1, num_keep)

    # get indices of tokens to keep
    indices = torch.argsort(noise, dim=1)
    idx_keep = indices[:, :num_keep]
    idx_mask = indices[:, num_keep:]

    return idx_keep, idx_mask

In [19]:
class DenseBlock(nn.Module):
    def __init__(self,
                 in_features: int = 768,
                 out_features: int = 2048,
               bias: bool = False,
              ) -> None:

        super().__init__()

        self.dense_block = nn.Sequential(nn.Linear(in_features= in_features,
                                                   out_features= out_features,
                                               bias=bias),
                                         nn.LayerNorm(normalized_shape= out_features),
                                         nn.GELU()
                                        )
    def forward(self,
                x: torch.Tensor
                ) -> torch.tensor:
        x = self.dense_block(x)
        return x

In [20]:
y = torch.rand(2,2,768)
DenseBlock(768,768*2)(y).shape

torch.Size([2, 2, 1536])

In [21]:
class ProjectionHead(nn.Module):
    def __init__(self,
                 in_features: int = 512,
                 hidden_features: int = 2048,
                 out_features: int = 512,
                 bias : bool = False
                 ) -> None:

        super().__init__()

        self.projection_head = nn.Sequential(DenseBlock(in_features= in_features,
                                                    out_features= hidden_features,
                                                    bias= bias
                                                    ),
                                         DenseBlock(in_features= hidden_features,
                                                    out_features= hidden_features,
                                                    bias= bias
                                                    ),
                                         nn.Linear(in_features= hidden_features,
                                                   out_features= out_features,
                                                   bias= bias
                                                   )
                                         )

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
        x = self.projection_head(x)
        return x

In [22]:
x = torch.randn(2,3,224,224)

In [23]:
backbone = VisionTransformer()

In [24]:
class ChexMSN(nn.Module):
    def __init__(self,
               backbone: nn.Module,
               masking_ratio: float = 0.15,
               ema_p: float = 0.996,
               focal: bool = True
              ) -> None:
        super().__init__()

        self.masking_ratio = masking_ratio
        self.ema_p = ema_p
        self.focal = focal

        self.anchor_backbone = backbone
        self.anchor_projection_head = ProjectionHead()

        self.target_backbone = copy.deepcopy(self.anchor_backbone)
        self.target_projection_head = copy.deepcopy(self.anchor_projection_head)

        deactivate_requires_grad(self.target_backbone)
        deactivate_requires_grad(self.target_projection_head)



    def forward(self,
                views:list[torch.tensor],
                focal: bool = True
                ) -> tuple[torch.Tensor]:

        update_momentum(model= self.anchor_backbone,
                        model_ema= self.target_backbone,
                        m = self.ema_p,
                       )
        update_momentum(model= self.anchor_projection_head,
                        model_ema= self.target_projection_head,
                        m = self.ema_p,
                       )
        projections = self._forward_all(batch=views,focal=self.focal)
        

        return projections

    def _target_forward(self,
                        view: torch.tensor
                        ) -> torch.Tensor:

        target_encodings = self.target_backbone(x= view,
                                                branch='target'
                                                )
        target_projections = self.target_projection_head(x= target_encodings)

        return target_projections


    def _anchor_forward(self,
                        view: torch.tensor
                        ) -> torch.Tensor:

        batch_size, _, _, width = view.shape
        seq_length = (width // self.anchor_backbone.patch_size) ** 2
        idx_keep, idx_mask = random_token_mask(size= (view.shape[0],seq_length),
                                               mask_ratio= self.masking_ratio
                                          )

        anchor_encodings = self.anchor_backbone(x= view,
                                                branch= 'anchor',
                                                idx_keep= idx_keep)
        anchor_projections = self.anchor_projection_head(x= anchor_encodings)

        return anchor_projections
    
    
    def _forward_all(self,
                     batch: list,
                     focal: bool = True
                     ) -> torch.tensor:
        
        target_view = batch[0][0]
        anchor_view_sim = batch[0][1]
        focal_views_sim = torch.concat(batch[0][2:],dim=0)
        anchor_view_age = batch[1][1]
        focal_views_age = torch.concat(batch[1][2:],dim=0)
        anchor_view_gender = batch[2][1]
        focal_views_gender = torch.concat(batch[2][2:],dim=0)
        
        
        target_projections = self._target_forward(target_view)
        
        anchor_projections_sim = self._anchor_forward(anchor_view_sim)
        if focal:
            anchor_focal_projections_sim = self._anchor_forward(focal_views_sim)
            similarity_projections = self._arrange_tokens(anchor_projections_sim,
                                                          anchor_focal_projections_sim,
                                                          num_focal = 10)
                                                          
        
        anchor_projections_age = self._anchor_forward(anchor_view_age)
        if focal:
            anchor_focal_projections_age =  self._anchor_forward(focal_views_age)
            age_projections = self._arrange_tokens(anchor_projections_age,
                                                   anchor_focal_projections_age,
                                                   num_focal = 10)
                                                    
        
        
        anchor_projections_gender = self._anchor_forward(anchor_view_gender)
        if focal:
            anchor_focal_projections_gender = self._anchor_forward(focal_views_gender)  
            gender_projections = self._arrange_tokens(anchor_projections_gender,
                                                      anchor_focal_projections_gender,
                                                      num_focal = 10)
                                                          
        if focal:
            anchor_projections = torch.stack((similarity_projections,
                                              age_projections,
                                              gender_projections
                                              ),
                                         dim= 0)
        else:
            anchor_projections = torch.stack((anchor_projections_sim,
                                              anchor_projections_age,
                                              anchor_projections_gender
                                              ),
                                             dim= 1)       

        return (anchor_projections,
                target_projections)
    
    def _arrange_tokens(self,
                        tensor1: torch.tensor,
                        tensor2:torch.tensor,
                        num_focal: int = 10
                        ) ->torch.tensor:

        a = torch.stack(torch.split(tensor1,1),0)
        b = torch.stack(torch.split(tensor2,num_focal),0)
        c = torch.cat((a.expand(-1,num_focal,-1),b),dim=1)[:,num_focal-1:]
        arranged_tokens = torch.cat(c.split(1),1).squeeze(0)
        return arranged_tokens

In [25]:
# x = []
# x = [torch.randn(10,3,224,224),torch.randn(10,3,224,224)]
# y = [torch.randn(10,3,96,96)for i in range(10)]

# x.extend(y)

next(iter(datalodaer))[0]

NameError: name 'datalodaer' is not defined

In [26]:
model = ChexMSN(backbone,focal=True)
print(len(next(iter(datalodaer))[0]))
model(next(iter(datalodaer)))

NameError: name 'datalodaer' is not defined

In [27]:
def arrange_tokens(tensor1: torch.tensor,
                   tensor2:torch.tensor,
                   num_focal: int = 10
                   ) ->torch.tensor:

    a = torch.stack(torch.split(tensor1,1),0)
    b = torch.stack(torch.split(tensor2,num_focal),0)
    c = torch.cat((a.expand(-1,num_focal,-1),b),dim=1)[:,num_focal-1:]
    
    return torch.cat(c.split(1),1).squeeze(0)

In [28]:
import math
import warnings
from typing import Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

In [29]:
def prototype_probabilities(queries: torch.tensor,
                            prototypes: torch.tensor,
                            temperature: float,
                            ) -> torch.tensor:
    """Returns probability for each query to belong to each prototype.

    Args:
        queries:
            Tensor with shape (batch_size, dim), projection head output
        prototypes:
            Tensor with shape (num_prototypes, dim)
        temperature:
            Inverse scaling factor for the similarity.

    Returns:
        Probability tensor with shape (batch_size, num_prototypes) which sums to 1 along
        the num_prototypes dimension.

    """                           
    return F.softmax(torch.matmul(queries, prototypes.T) / temperature, dim=1)

In [30]:
def sharpen(probabilities: torch.tensor, 
            temperature: float
           ) -> torch.tensor:
    """Sharpens the probabilities with the given temperature.

    Args:
        probabilities:
            Tensor with shape (batch_size, dim)
        temperature:
            Temperature in (0, 1]. Lower temperature results in stronger sharpening (
            output probabilities are less uniform).
    Returns:
        Probabilities tensor with shape (batch_size, num_prototypes).

    """
    probabilities = probabilities ** (1.0 / temperature)
    probabilities /= torch.sum(probabilities, dim=1, keepdim=True)
    return probabilities

In [31]:
@torch.no_grad()
def sinkhorn(probabilities: torch.tensor,
             iterations: int = 3,
             gather_distributed: bool = False,
            ) -> torch.tensor:
    """Runs sinkhorn normalization on the probabilities as described in [0].

    Code inspired by [1].

    - [0]: Masked Siamese Networks, 2022, https://arxiv.org/abs/2204.07141
    - [1]: https://github.com/facebookresearch/msn

    Args:
        probabilities:
            Probabilities tensor with shape (batch_size, num_prototypes).
        iterations:
            Number of iterations of the sinkhorn algorithms. Set to 0 to disable.
        gather_distributed:
            If True then features from all gpus are gathered during normalization.
    Returns:
        A normalized probabilities tensor.

    """
    if iterations <= 0:
        return probabilities


    num_targets, num_prototypes = probabilities.shape
    probabilities = probabilities.T
    sum_probabilities = torch.sum(probabilities)

    probabilities = probabilities / sum_probabilities

    for _ in range(iterations):
        # normalize rows
        row_sum = torch.sum(probabilities, dim=1, keepdim=True)

        probabilities /= row_sum
        probabilities /= num_prototypes

        # normalize columns
        probabilities /= torch.sum(probabilities, dim=0, keepdim=True)
        probabilities /= num_targets

    probabilities *= num_targets
    return probabilities.T

In [32]:
def regularization_loss(mean_anchor_probs: torch.tensor
                       ) -> torch.tensor:
    """Calculates mean entropy regularization loss."""
    loss = -torch.sum(torch.log(mean_anchor_probs ** (-mean_anchor_probs)))
    loss += math.log(float(len(mean_anchor_probs)))
    return loss

In [33]:
prototypes = nn.ModuleList([nn.Linear(512,1024),
                           nn.Linear(512,1024),
                           nn.Linear(512,1024)]
                          )


In [34]:
class MSNLoss(nn.Module):
    def __init__(self,
                 temperature: float = 0.1,
                 sinkhorn_iterations: int = 3,
                 similarity_weight: float = 1.0,
                 age_weight: float = 1.0,
                 gender_weight: float = 1.0,
                 regularization_weight: float = 1.0,
               ) -> None:

        super().__init__()
        
        self.temperature = temperature
        self.sinkhorn_iterations = sinkhorn_iterations
        self.similarity_weight = similarity_weight
        self.age_weight = age_weight
        self.gender_weight = gender_weight
        self.regularization_weight = regularization_weight
    
    
    def forward(self,
                anchors: torch.tensor,
                targets: torch.tensor,
                prototypes: torch.tensor,
                target_sharpen_temperature: float = 0.25,
                focal: bool = True
                      ) -> torch.tensor:
        
        similarity_loss = self.similarity_weight * self._forward_loss(anchors=anchors[0] if focal else anchors[:,0],
                                                                      targets=targets[:,0],
                                                                      prototypes=prototypes[0].weight
                                                                     )
        age_loss = self.age_weight * self._forward_loss(anchors=anchors[1] if focal else anchors[:,1],
                                                           targets=targets[:,1],
                                                           prototypes=prototypes[1].weight
                                                          )
        gender_loss = self.gender_weight * self._forward_loss(anchors=anchors[2] if focal else anchors[:,2],
                                                              targets=targets[:,2],
                                                              prototypes=prototypes[2].weight
                                                             )
        
        loss = similarity_loss + gender_loss + gender_loss
        return loss
    
    def _forward_loss(self,
                anchors: torch.tensor,
                targets: torch.tensor,
                prototypes: torch.tensor,
                target_sharpen_temperature: float = 0.25,
               ) -> torch.tensor:

        num_views = anchors.shape[0] // targets.shape[0]
        anchors = F.normalize(anchors, dim=1)
        targets = F.normalize(targets, dim=1)
        prototypes = F.normalize(prototypes, dim=1)

        anchor_probs = prototype_probabilities(anchors, 
                                               prototypes, 
                                               temperature=self.temperature
                                              )

        with torch.no_grad():
            target_probs = prototype_probabilities(targets, 
                                                   prototypes, 
                                                   temperature=self.temperature
                                                   )
            target_probs = sharpen(target_probs, temperature=target_sharpen_temperature)
            if self.sinkhorn_iterations > 0:
                target_probs = sinkhorn(probabilities=target_probs,
                                        iterations=self.sinkhorn_iterations,
                                        )
            target_probs = torch.repeat_interleave(target_probs, repeats=num_views, dim=0)

        loss = torch.mean(torch.sum(torch.log(anchor_probs ** (-target_probs)), dim=1))

        # regularization loss
        if self.regularization_weight > 0:
            mean_anchor_probs = torch.mean(anchor_probs, dim=0)
            reg_loss = regularization_loss(mean_anchor_probs=mean_anchor_probs)
            loss += self.regularization_weight * reg_loss
            
        return loss

In [35]:
model = ChexMSN(backbone,focal=True)
a,b = model(next(iter(datalodaer)))

NameError: name 'datalodaer' is not defined

In [36]:
criterion = MSNLoss(temperature=.5)
criterion(a, b,prototypes,focal=True)

NameError: name 'a' is not defined

In [37]:
from typing import Dict, List, Optional, Tuple, Union

import torchvision.transforms as T
from PIL.Image import Image
from torch import Tensor


from lightly.transforms.multi_view_transform import MultiViewTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE


class MSNTransform(MultiViewTransform):
    """Implements the transformations for MSN [0].

    Input to this transform:
        PIL Image or Tensor.

    Output of this transform:
        List of Tensor of length 2 * random_views + focal_views. (12 by default)

    Applies the following augmentations by default:
        - Random resized crop
        - Random horizontal flip
        - ImageNet normalization

    Generates a set of random and focal views for each input image. The generated output
    is (views, target, filenames) where views is list with the following entries:
    [random_views_0, random_views_1, ..., focal_views_0, focal_views_1, ...].

    - [0]: Masked Siamese Networks, 2022: https://arxiv.org/abs/2204.07141

    Attributes:
        random_size:
            Size of the random image views in pixels.
        focal_size:
            Size of the focal image views in pixels.
        random_views:
            Number of random views to generate.
        focal_views:
            Number of focal views to generate.
        random_crop_scale:
            Minimum and maximum size of the randomized crops for the relative to random_size.
        focal_crop_scale:
            Minimum and maximum size of the randomized crops relative to focal_size.
        hf_prob:
            Probability that horizontal flip is applied.
        vf_prob:
            Probability that vertical flip is applied.
        normalize:
            Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize.
    """

    def __init__(
        self,
        random_size: int = 224,
        focal_size: int = 96,
        random_views: int = 2,
        focal_views: int = 10,
        affine_dgrees: int = 15,
        affine_scale: Tuple[float,float]= (.9, 1.1),
        affine_shear: int = 0,
        affine_translate: Tuple[float,float] = (0.1, 0.1),
        random_crop_scale: Tuple[float, float] = (0.3, 1.0),
        focal_crop_scale: Tuple[float, float] = (0.05, 0.3),
        hf_prob: float = 0.5,
        vf_prob: float = 0.0,
        normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
    ):
        random_view_transform = MSNViewTransform(
            affine_dgrees=affine_dgrees,
            affine_scale=affine_scale,
            affine_shear=affine_shear,
            affine_translate=affine_translate,
            crop_size=random_size,
            crop_scale=random_crop_scale,
            hf_prob=hf_prob,
            vf_prob=vf_prob,
            normalize=normalize,
        )
        focal_view_transform = MSNViewTransform(
            affine_dgrees=affine_dgrees,
            affine_scale=affine_scale,
            affine_shear=affine_shear,
            affine_translate=affine_translate,
            crop_size=focal_size,
            crop_scale=focal_crop_scale,
            hf_prob=hf_prob,
            vf_prob=vf_prob,
            normalize=normalize,
        )
        transforms = [random_view_transform] * random_views
        transforms += [focal_view_transform] * focal_views
        super().__init__(transforms=transforms)


class MSNViewTransform:
    def __init__(
        self,
        affine_dgrees: int = 15,
        affine_scale: Tuple[float,float]= (.9, 1.1),
        affine_shear: int = 0,
        affine_translate: Tuple[float,float] = (0.1, 0.1),
        crop_size: int = 224,
        crop_scale: Tuple[float, float] = (0.3, 1.0),
        hf_prob: float = 0.5,
        vf_prob: float = 0.0,
        normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
    ):

        transform = [
            T.RandomAffine(degrees=affine_dgrees, 
                           scale=affine_scale, 
                           shear=affine_shear, 
                           translate=affine_translate),
            T.RandomResizedCrop(size=crop_size, scale=crop_scale),
            T.RandomHorizontalFlip(p=hf_prob),
            T.RandomVerticalFlip(p=vf_prob),
            T.ToTensor(),
            T.Normalize(mean=normalize["mean"], std=normalize["std"]),
        ]

        self.transform = T.Compose(transform)

    def __call__(self, image: Union[Tensor, Image]) -> Tensor:
        """
        Applies the transforms to the input image.

        Args:
            image:
                The input image to apply the transforms to.

        Returns:
            The transformed image.

        """
        transformed: Tensor = self.transform(image)
        return transformed

In [38]:
import os
import torch
import torch.nn as nn 
import pytorch_lightning as pl

import random
from collections import Counter
from typing import Tuple, Optional
from torch import Tensor
from PIL import Image
from torch.utils.data import Dataset

import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.transforms as T
from tqdm import tqdm
from torch import Tensor
from copy import deepcopy
from typing import List, Tuple, Dict
from torch.utils.data import DataLoader

from lightly.transforms import MSNTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE

class ChexMSNDataset(Dataset):
    def __init__(self, 
                 data_dir: str,
                 transforms: nn.Module,
                 same = True
                 ) -> None:
      
        self.meta = pd.read_csv(data_dir)
        self.all_images = list(self.meta.path)
        self.transform = transforms
        self.same = same
        
    def __len__(self
                ) -> int:
        return len(self.all_images)
    
    def __getitem__(self,
                    index: int
                    ) -> Tuple[torch.Tensor]:

        
        target_path = self.all_images[index]
        image_id = target_path.split('/')[-1][:-4]
        img_age_path, img_gender_path = self._retrieve_anchors(image_id=image_id,
                                                               meta = self.meta,
                                                               same=self.same)

        img_target = Image.open(fp=target_path).convert('RGB')
        img_target = self.transform(img_target)
        
        img_age = Image.open(fp=img_age_path).convert('RGB')
        img_age = self.transform(img_age)

        img_gender = Image.open(fp=img_gender_path).convert('RGB')
        img_gender = self.transform(img_gender)

        return (img_target,img_age,img_gender)
    
    
    def _retrieve_anchors(self,
                          image_id: str,
                          meta: pd.DataFrame,
                          same: bool = False) -> Tuple[str]:
        record = meta[meta.dicom_id == image_id]
    
        subject_id = list(record.subject_id)[0]
        age_groub =list(record.ageR5)[0] 
        gender = list(record.gender)[0]
    
        group = meta[meta.ageR5 == age_groub]
    
        if same:
            candidate_anchors = group[group.gender == gender]
            candidate_anchors = candidate_anchors[candidate_anchors.subject_id != subject_id]
            images= list(candidate_anchors.path)
            sampled_images = random.sample(images,k=2)
            image_age, image_gender = sampled_images[0],sampled_images[1]
            return image_age, image_gender
        else:
            candidate_anchors = group
            candidate_anchors = candidate_anchors[candidate_anchors.subject_id != subject_id]
            images= list(candidate_anchors.path)
            image_age = random.sample(images,k=1)[0]
            candidate_anchors = candidate_anchors[candidate_anchors.gender == gender]
            images= list(candidate_anchors.path)
            image_gender = random.sample(images,k=1)[0]
            return image_age, image_gender
        

In [44]:
transforms = MSNTransform()

In [50]:
dataset = ChexMSNDataset(data_dir='../data/meta.csv',transforms= transforms)
datalodaer = DataLoader(dataset=dataset,batch_size=2,num_workers=8,pin_memory=True)

In [51]:
next(iter(datalodaer))

[[tensor([[[[-1.6213, -1.6213, -1.6213,  ..., -1.6213, -1.6213, -1.6213],
            [-1.6213, -1.6213, -1.6213,  ..., -1.6213, -1.6213, -1.6213],
            [-1.6213, -1.6213, -1.6213,  ..., -1.6042, -1.6213, -1.6213],
            ...,
            [-1.6213, -1.6213, -1.6213,  ..., -1.6213, -1.6213, -1.6213],
            [-1.6213, -1.6213, -1.6213,  ..., -1.6213, -1.6213, -1.6213],
            [-1.6213, -1.6213, -1.6213,  ..., -1.6213, -1.6213, -1.6213]],
  
           [[-1.5280, -1.5280, -1.5280,  ..., -1.5280, -1.5280, -1.5280],
            [-1.5280, -1.5280, -1.5280,  ..., -1.5280, -1.5280, -1.5280],
            [-1.5280, -1.5280, -1.5280,  ..., -1.5105, -1.5280, -1.5280],
            ...,
            [-1.5280, -1.5280, -1.5280,  ..., -1.5280, -1.5280, -1.5280],
            [-1.5280, -1.5280, -1.5280,  ..., -1.5280, -1.5280, -1.5280],
            [-1.5280, -1.5280, -1.5280,  ..., -1.5280, -1.5280, -1.5280]],
  
           [[-1.2990, -1.2990, -1.2990,  ..., -1.2990, -1.2990, -1.299

In [47]:
class ChexMSNModel(pl.LightningModule):
    def __init__(self,
                 model: nn.Module,
                 criterion: nn.Module,
                 num_prototypes: int = 2048,
                 learning_rate: float =  1e-3,
                 weight_decay: float = 0.0,
                 max_epochs: int = 50,
                 mask_ratio: float = 0.15,

                ) -> None:
        super().__init__()

        self.model = model
        self.prototypes = nn.ModuleList([nn.Linear(in_features=512,out_features=num_prototypes),
                                        nn.Linear(in_features=512,out_features=num_prototypes),
                                        nn.Linear(in_features=512,out_features=num_prototypes)])
                
    def training_step(self, 
                      batch: List[Tensor], 
                      batch_idx: int
                     ) -> float:
        

        anchors, target = model(batch) 
        loss = criterion(anchors,target,self.prototypes)
        self.log("train_loss", loss, on_epoch= True,on_step=False , logger=True)      
        return loss


    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), 
                                      lr=self.lr,
                                      weight_decay=self.weight_decay)
        
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                               eta_min=0.00001,
                                                               T_max=self.max_epochs)
        return {'optimizer': optimizer,
               'lr_scheduler': scheduler
               }

    


In [48]:
chexmsn = ChexMSNModel(model=model,criterion=criterion)

In [49]:
chexmsn.training_step(next(iter(datalodaer)),0)

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/pytorch_lightning/core/module.py:420: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(22.8914, grad_fn=<AddBackward0>)

In [78]:
len(next(iter(datalodaer))[2])

12

In [75]:
model(next(iter(datalodaer))[0])

(tensor([[[ 0.0263,  0.2974, -0.1393,  ...,  0.5074,  0.1716, -0.1627],
          [ 0.1130,  0.0252,  0.1407,  ...,  0.4792, -0.2396, -0.1600],
          [ 0.1798,  0.1764,  0.1032,  ...,  0.1357, -0.6334,  0.0844],
          ...,
          [ 0.3475,  0.3314,  0.0842,  ...,  0.2358, -0.3051, -0.0203],
          [ 0.4003,  0.0806, -0.2598,  ...,  0.1931, -0.7471,  0.0807],
          [ 0.1596,  0.3951, -0.0781,  ...,  0.2155,  0.0606,  0.2141]],
 
         [[-0.0033,  0.2807, -0.0862,  ...,  0.4857,  0.1266, -0.1063],
          [ 0.0346,  0.0416, -0.0430,  ...,  0.3623, -0.0658,  0.0156],
          [ 0.2710,  0.1721,  0.0216,  ...,  0.0565, -0.7129,  0.1693],
          ...,
          [ 0.2687,  0.3021,  0.1107,  ...,  0.2366, -0.4238, -0.0339],
          [ 0.3268, -0.0498, -0.3285,  ...,  0.1033, -0.7735,  0.1527],
          [ 0.1666,  0.3889, -0.1307,  ...,  0.1600,  0.1442,  0.2036]],
 
         [[ 0.0555,  0.2447, -0.1671,  ...,  0.5027,  0.1655, -0.1685],
          [-0.0567,  0.0804,

In [None]:
import numpy as np
import copy
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fb045fec710>

In [None]:
class PatchEmbeddings(nn.Module):
    def __init__(self,
                 images_size: int =  224,
                 patch_size: int = 16,
                 num_channels: int = 3,
                 embedding_dim: int = 512
                ) -> None:
      super().__init__()
      """
        Role: convert the input image of size = (b,c,h,w) into patches sequence
              of size = (b,p,embedding_dim) where
              b = batch size
              c = number of channels
              h = height
              w = width

        Inputs:
            image_size - Dimensionality of the input image (width or height)
            patch_size - Dimensionality of the image patch (width or height)
            embedding_dim - Dimensionality of the patch embedding
            num_channels - Number of channels in the input images (usually 3)
      """
      self.image_size = images_size
      self.patch_size = patch_size
      self.num_channels = num_channels
      self.hidden_size = embedding_dim

      self.num_patches = (self.image_size // self.patch_size) ** 2

      self.projection_layer = nn.Conv2d(in_channels= self.num_channels,
                                        out_channels= self.hidden_size,
                                        kernel_size= self.patch_size,
                                        stride= self.patch_size
                                       )

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:
      x = self.projection_layer(x)
      x = x.flatten(2).transpose(1,2)
      return x

In [None]:
class InputEmbeddings(nn.Module):
    def __init__(self,
                 images_size: int =  224,
                 num_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 512,
                 num_cls_tokens: int = 3,
                 dropout_p: float = 0.0,
                 learnable_pos_encoding: bool = True
                ) -> None:
      super().__init__()
      """
        Role: convert the input image into patches as well as adding CLS token
        and positional encoding

        Inputs:
            image_size - Dimensionality of the input image (width or height)
            num_channels - Number of channels in the input images (usually 3)
            patch_size - Dimensionality of the image patch (width or height)
            embedding_dim - Dimensionality of the patch embedding
            num_cls_tokens - number of classification tokens to add to the sequence
            dropout_p -  percentage of applied dropout
            learnable_pos_encoding - specify typ of positional encoding (fixed or learnable)
      """
      self.image_size = images_size
      self.patch_size = patch_size
      self.num_channels = num_channels
      self.embedding_dim = embedding_dim
      self.num_cls_tokens = num_cls_tokens
      self.dropout_p = dropout_p
      self.learnable_pos_encoding = learnable_pos_encoding

      self.patch_embeddings = PatchEmbeddings(images_size= self.image_size,
                                              patch_size= self.patch_size,
                                              num_channels= self.num_channels,
                                              embedding_dim= self.embedding_dim
                                              )

      self.cls_token = nn.Parameter(torch.randn(size= (1, self.num_cls_tokens, self.embedding_dim)))

      num_patches = self.patch_embeddings.num_patches
      if self.learnable_pos_encoding:
        num_patches = self.patch_embeddings.num_patches
        self.positional_embeddings = nn.Parameter(torch.empty(1, num_patches , self.embedding_dim).normal_(std=0.02))

      else:
        self.positional_embeddings = self._get_positional_embeddings(num_patches=num_patches,
                                                                     embedding_dim=self.embedding_dim
                                                                    )

      self.dropout = nn.Dropout(self.dropout_p)

    def forward(self,
                x: torch.Tensor
                ) -> torch.Tensor:

      x = self.patch_embeddings(x)
      batch_size, _, _ = x.size()
      cls_token = self.cls_token.expand(batch_size,-1,-1)
      x = torch.cat([cls_token,x], dim=1)
      x = x + self.positional_embeddings
      x = self.dropout(x)
      return x


    # def _get_positional_embeddings(self,
    #                              num_patches: int,
    #                              embedding_dim: int
    #                              ) -> torch.Tensor:
    #   poaitional_embeddings  = torch.ones(num_patches + self.num_cls_tokens, embedding_dim)
    #   for i in range(num_patches + self.num_cls_tokens):
    #       for j in range(embedding_dim):
    #           poaitional_embeddings [i][j] = np.sin(i / (10000 ** (j / embedding_dim))) \
    #           if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / embedding_dim)))
    #   return poaitional_embeddings.unsqueeze(0)


In [None]:
class FeedForwardNet(nn.Module):
    def __init__(self,
               embedding_dim: int = 512,
               hidden_dim = 192,
               dropout_p: float = 0.0
               ) -> None:
      super().__init__()
      """
        Role: Feed forward netork of attention head

        Inputs:
            embedding_dim - Dimensionality of the patch embedding
            hidden_dim - Dimensionality of hidden layer in feed-forward network
            dropout_p -  percentage of applied dropout
            learnable_pos_encoding - specify typ of positional encoding (fixed or learnable)
      """
      self.feed_forward = nn.Sequential(nn.Linear(in_features= embedding_dim,
                                                  out_features= hidden_dim
                                                 ),
                                        nn.GELU(),
                                        nn.Linear(in_features=hidden_dim,
                                                  out_features=embedding_dim
                                                 ),
                                        nn.Dropout(dropout_p)
                                       )
    def forward(self,
              x: torch.tensor
              )-> torch.Tensor:
      x = self.feed_forward(x)
      return x

In [None]:
class AttentionBlock(nn.Module):
    def __init__(self,
                 embedding_dim: int = 512,
                 hidden_dim: int = 192,
                 num_heads: int = 4,
                 dropout_p=0.0,
                 pre_layer_norm: bool= True
                 ) -> None:
      super().__init__()

      """ Role: capture the interaction between
          Inputs:
              embedding_dim - Dimensionality of input and attention feature vectors
              hidden_dim - Dimensionality of hidden layer in feed-forward network
              num_heads - Number of heads to use in the Multi-Head Attention block
              dropout - Amount of dropout to apply in the feed-forward network
              pre_layer_norm - Specifies where to apply the layer norm (before or after)
      """
      self.pre_layer_norm = pre_layer_norm
      self.layer_norm_1 = nn.LayerNorm(normalized_shape= embedding_dim)

      self.attention_head = nn.MultiheadAttention(embed_dim= embedding_dim,
                                                  num_heads= num_heads,
                                                  dropout= dropout_p)

      self.layer_norm_2 = nn.LayerNorm(normalized_shape= embedding_dim)

      self.feed_forward = FeedForwardNet(embedding_dim= embedding_dim,
                                         hidden_dim= hidden_dim,
                                         dropout_p= dropout_p
                                        )

    #attention weights from here
    def forward(self,
                x: torch.tensor
                )-> torch.Tensor:

        if self.pre_layer_norm:
          normalization_out1 = self.layer_norm_1(x)
          attention_out = self.attention_head(normalization_out1, normalization_out1, normalization_out1)[0]
          addition_out1 = x + attention_out
          normalization_out2 = self.layer_norm_2(addition_out1)
          feed_forward_out = self.feed_forward(normalization_out2)
          layer_out = addition_out1 + feed_forward_out

        else:
          attention_out = self.attention_head(x,x,x)[0]
          addition_out1 = x + attention_out
          normalization_out = self.layer_norm_1(addition_out1)
          feed_forward_out = self.feed_forward(normalization_out)
          addition_out2 = normalization_out + feed_forward_out
          layer_out = self.layer_norm_2(addition_out2)

        return attention_out

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self,
                 num_layers: int = 6,
                 num_heads: int = 4,
                 hidden_dim: int = 192,
                 pre_layer_norm: bool= True,
                 images_size: int =  224,
                 num_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 512,
                 num_cls_tokens: int = 3,
                 dropout_p: float = 0.0,
                 learnable_pos_encoding: bool = False
                ) -> None:
      super().__init__()


      self.image_size = images_size
      self.patch_size = patch_size
      self.num_channels = num_channels
      self.embedding_dim = embedding_dim
      self.num_cls_tokens = num_cls_tokens
      self.dropout_p = dropout_p
      self.learnable_pos_encoding = learnable_pos_encoding

      self.input_embeddings = PatchEmbeddings(images_size= self.image_size,
                                              num_channels= self.num_channels,
                                              patch_size= self.patch_size,
                                              embedding_dim= self.embedding_dim,
                                              )
                                              # num_cls_tokens= self.num_cls_tokens,
                                              # dropout_p=self.dropout_p,
                                              # learnable_pos_encoding=self.learnable_pos_encoding)



      self.num_layers = num_layers
      self.num_heads = num_heads
      self.hidden_dim = hidden_dim
      self.pre_layer_norm = pre_layer_norm

      self.attention_block = AttentionBlock(embedding_dim= self.embedding_dim,
                                            hidden_dim= self.hidden_dim,
                                            num_heads= self.num_heads,
                                            dropout_p=self.dropout_p,
                                            pre_layer_norm=self.pre_layer_norm)


      self.cls_token = nn.Parameter(torch.randn(size= (1, self.num_cls_tokens, self.embedding_dim)))
      num_patches = self.input_embeddings.num_patches
      self.positional_embeddings = nn.Parameter(torch.empty(1, num_patches + self.num_cls_tokens, self.embedding_dim).normal_(std=0.02))

      self.encoder = nn.Sequential(*[self.attention_block for layer in range(num_layers)])


    def forward(self,
                x: torch.tensor,
                branch: str = 'target'
                )-> torch.Tensor:

      if branch == 'target':
        x = self.input_embeddings(x)
        batch_size, _, _ = x.size()
        cls_token = self.cls_token.expand(batch_size,-1,-1)
        x = torch.cat([cls_token,x], dim=1)
        x = x + self.positional_embeddings
        x = self.encoder(x)
        return x#[:,0:self.num_cls_tokens]

      elif branch == 'anchor':
        x = self.input_embeddings(x)
        batch_size, _, _ = x.size()
        cls_token = self.cls_token.expand(batch_size,-1,-1)
        x = torch.cat([cls_token,x], dim=1)
        x = x + self.positional_embeddings
        x = torch.cat((x[:,:1],x[:,self.num_cls_tokens:]),dim = 1)
        x = self.encoder(x)
        return x#[:,0]

In [330]:
h = torch.tensor([[1,2,3],[4,5,6]])
h

tensor([[1, 2, 3],
        [4, 5, 6]])

In [492]:
h.repeat((5,1))

tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]])

In [493]:
h.repeat(5,1)

tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]])

In [494]:
torch.cat([h for _ in range(5)], dim=0)

tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]])