In [77]:
import torch
import torch.nn as nn
import numpy as np
# Default to the CPU and switch to the GPU only when torch reports one is usable.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [78]:
# Display the installed torch version so the recorded outputs can be tied to it.
torch.__version__

'1.7.1'

In [94]:
import torch.nn.functional as F
from torch.nn import MultiheadAttention, Linear, Dropout, LayerNorm
from typing import Optional, Any, Union, Callable
from torch import Tensor

def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Map a legacy activation name ("relu" or "gelu") to the matching functional."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.
    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        device: if given, the layer's parameters are moved to this device after construction.
        dtype: if given, the layer's floating point parameters are cast to this dtype.
        kdim: key feature size forwarded to ``MultiheadAttention`` (default: ``d_model``).
        vdim: value feature size forwarded to ``MultiheadAttention`` (default: ``d_model``).
            NOTE(review): this layer passes the same tensor as query, key and value, so a
            ``kdim``/``vdim`` different from ``d_model`` presumably only makes sense for
            parameter counting; the forward pass is expected to reject the mismatched
            key/value width. Confirm before relying on it.
    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None, kdim=None, vdim=None) -> None:
        super(TransformerEncoderLayer, self).__init__()
        if kdim is None:
            kdim = d_model
        if vdim is None:
            vdim = d_model
        # The attention module is always built sequence-first; ``batch_first`` is
        # honoured by transposing in ``forward`` so this also works on torch
        # releases whose MultiheadAttention has no ``batch_first`` argument.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, kdim=kdim, vdim=vdim)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.batch_first = batch_first
        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

        # Apply the requested device/dtype once every submodule exists. Only the
        # options that were actually supplied are forwarded to ``Module.to``.
        placement = {}
        if device is not None:
            placement['device'] = device
        if dtype is not None:
            placement['dtype'] = dtype
        if placement:
            self.to(**placement)

    def __setstate__(self, state):
        # State saved without an activation entry falls back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Returns:
            A tensor with the same layout as ``src``.
        Shape:
            see the docs in Transformer class.
        """

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = src
        if self.batch_first:
            # (batch, seq, feature) -> (seq, batch, feature) for the attention module.
            x = x.transpose(0, 1)

        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        if self.batch_first:
            # Restore the caller's (batch, seq, feature) layout.
            x = x.transpose(0, 1)
        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        # Only the attended values are used, so the averaged attention weights
        # are not requested.
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
    
class Transformer(nn.Module):
    """Linear dimension reducer followed by a single-layer, single-head transformer encoder.

    Args:
        embedding_dim: feature size of the incoming tensor's last axis.
        feedforward_dim: hidden size of the encoder layer's feed-forward network.
        reduction_factor: ``d_model`` of the encoder is
            ``int(embedding_dim / reduction_factor)``.
        kdim: forwarded to the encoder layer's attention module (``None`` -> ``d_model``).
        vdim: forwarded to the encoder layer's attention module (``None`` -> ``d_model``).

    ``logit_scale`` is a learnable scalar initialised to 1; ``forward`` does not use it.
    """

    def __init__(self, embedding_dim=1024, feedforward_dim=512, reduction_factor=2, kdim=None, vdim=None):
        super().__init__()
        d_model = int(embedding_dim / reduction_factor)
        encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=1, dim_feedforward=feedforward_dim, kdim=kdim, vdim=vdim)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.dim_reducer = nn.Linear(embedding_dim, d_model)
        # Created on the default device like every other parameter of this module;
        # callers place the whole model with ``.to(...)``. Pinning it to a global
        # device here would leave the parameters on mixed devices.
        self.logit_scale = nn.Parameter(torch.ones([]))

    def forward(self, input):
        """Reduce the feature axis, run the encoder, and return the result.

        The first two axes are swapped before the encoder call and swapped back
        afterwards, so the output has the same axis order as ``input``, with the
        last axis reduced to ``d_model``.
        """
        output = self.dim_reducer(input)
        output = output.permute(1, 0, 2)
        output = self.transformer_encoder(output)
        return output.permute(1, 0, 2)
    


In [95]:
# Smoke test: build the model with its default arguments and run one random input through it.
model = Transformer().to(device)
# Random input of shape (1, 512, 1024); the last axis matches the default embedding_dim=1024.
arr = torch.rand((1,512,1024)).to(device)
out = model(arr)

In [96]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-parameter size table for ``model`` and return the trainable total.

    Only parameters with ``requires_grad`` set are listed and counted.

    Args:
        model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The summed ``numel()`` of all trainable parameters.
    """
    report = PrettyTable(["Modules", "Parameters"])
    # Collect (name, element count) for every trainable parameter, in iteration order.
    trainable = [
        (label, tensor.numel())
        for label, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for label, size in trainable:
        report.add_row([label, size])
    total_params = sum(size for _, size in trainable)
    print(report)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
# Parameter breakdown of the default-argument Transformer held in `model`.
count_parameters(model)


+--------------------------------------------------------+------------+
|                        Modules                         | Parameters |
+--------------------------------------------------------+------------+
|                      logit_scale                       |     1      |
| transformer_encoder.layers.0.self_attn.in_proj_weight  |   786432   |
|  transformer_encoder.layers.0.self_attn.in_proj_bias   |    1536    |
| transformer_encoder.layers.0.self_attn.out_proj.weight |   262144   |
|  transformer_encoder.layers.0.self_attn.out_proj.bias  |    512     |
|      transformer_encoder.layers.0.linear1.weight       |   262144   |
|       transformer_encoder.layers.0.linear1.bias        |    512     |
|      transformer_encoder.layers.0.linear2.weight       |   262144   |
|       transformer_encoder.layers.0.linear2.bias        |    512     |
|       transformer_encoder.layers.0.norm1.weight        |    512     |
|        transformer_encoder.layers.0.norm1.bias         |    51

2102785

In [81]:
class MLP_1_hidden(nn.Module):
    """Two stacked ``Linear`` layers: ``input_size -> hidden_size -> hidden_size``.

    No activation is applied between ``fc1`` and ``out``, so the forward pass is
    a composition of two affine maps.

    Args:
        input_size: feature size of the incoming tensor's last axis.
        hidden_size: output size of both linear layers.

    ``logit_scale`` is a learnable scalar initialised to 1; ``forward`` does not use it.
    """

    def __init__(self, input_size = 1024, hidden_size = 512):
        super().__init__()
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.out = torch.nn.Linear(self.hidden_size, self.hidden_size)
        # Created on the default device like fc1/out; callers place the whole
        # model with ``.to(...)``. Pinning it to a global device here would leave
        # the parameters on mixed devices.
        self.logit_scale = nn.Parameter(torch.ones([]))

    def forward(self, input):
        """Apply ``fc1`` then ``out`` to ``input`` and return the result."""
        return self.out(self.fc1(input))

In [82]:
# Baseline for comparison: parameter count of the two-layer linear model with default sizes.
linear_model = MLP_1_hidden().to(device)
count_parameters(linear_model)

+-------------+------------+
|   Modules   | Parameters |
+-------------+------------+
| logit_scale |     1      |
|  fc1.weight |   524288   |
|   fc1.bias  |    512     |
|  out.weight |   262144   |
|   out.bias  |    512     |
+-------------+------------+
Total Trainable Params: 787457


787457

In [88]:
# reduction_factor=4 gives the encoder d_model = int(1024 / 4) = 256.
transformer_model = Transformer(1024, 512, 4).to(device)
count_parameters(transformer_model)

+--------------------------------------------------------+------------+
|                        Modules                         | Parameters |
+--------------------------------------------------------+------------+
|                      logit_scale                       |     1      |
| transformer_encoder.layers.0.self_attn.in_proj_weight  |   196608   |
|  transformer_encoder.layers.0.self_attn.in_proj_bias   |    768     |
| transformer_encoder.layers.0.self_attn.out_proj.weight |   65536    |
|  transformer_encoder.layers.0.self_attn.out_proj.bias  |    256     |
|      transformer_encoder.layers.0.linear1.weight       |   131072   |
|       transformer_encoder.layers.0.linear1.bias        |    512     |
|      transformer_encoder.layers.0.linear2.weight       |   131072   |
|       transformer_encoder.layers.0.linear2.bias        |    256     |
|       transformer_encoder.layers.0.norm1.weight        |    256     |
|        transformer_encoder.layers.0.norm1.bias         |    25

789505

In [92]:
# reduction_factor=2 gives d_model = 512; these arguments equal the Transformer() defaults.
transformer_model = Transformer(1024, 512, 2).to(device)
count_parameters(transformer_model)

+--------------------------------------------------------+------------+
|                        Modules                         | Parameters |
+--------------------------------------------------------+------------+
|                      logit_scale                       |     1      |
| transformer_encoder.layers.0.self_attn.in_proj_weight  |   786432   |
|  transformer_encoder.layers.0.self_attn.in_proj_bias   |    1536    |
| transformer_encoder.layers.0.self_attn.out_proj.weight |   262144   |
|  transformer_encoder.layers.0.self_attn.out_proj.bias  |    512     |
|      transformer_encoder.layers.0.linear1.weight       |   262144   |
|       transformer_encoder.layers.0.linear1.bias        |    512     |
|      transformer_encoder.layers.0.linear2.weight       |   262144   |
|       transformer_encoder.layers.0.linear2.bias        |    512     |
|       transformer_encoder.layers.0.norm1.weight        |    512     |
|        transformer_encoder.layers.0.norm1.bias         |    51

2102785

In [103]:
# Same sizes as the defaults, but with explicit key/value feature sizes for the attention module.
# NOTE(review): the recorded output lists 28672 parameters for k_proj_weight / v_proj_weight,
# which equals 512 * 56 rather than the 512 * 128 expected for kdim=vdim=128 -- the output
# looks like it came from an earlier run with different values; re-run to confirm.
transformer_model = Transformer(embedding_dim=1024, feedforward_dim=512, reduction_factor=2, kdim=128, vdim=128).to(device)
count_parameters(transformer_model)

+--------------------------------------------------------+------------+
|                        Modules                         | Parameters |
+--------------------------------------------------------+------------+
|                      logit_scale                       |     1      |
|  transformer_encoder.layers.0.self_attn.q_proj_weight  |   262144   |
|  transformer_encoder.layers.0.self_attn.k_proj_weight  |   28672    |
|  transformer_encoder.layers.0.self_attn.v_proj_weight  |   28672    |
|  transformer_encoder.layers.0.self_attn.in_proj_bias   |    1536    |
| transformer_encoder.layers.0.self_attn.out_proj.weight |   262144   |
|  transformer_encoder.layers.0.self_attn.out_proj.bias  |    512     |
|      transformer_encoder.layers.0.linear1.weight       |   262144   |
|       transformer_encoder.layers.0.linear1.bias        |    512     |
|      transformer_encoder.layers.0.linear2.weight       |   262144   |
|       transformer_encoder.layers.0.linear2.bias        |    51

1635841