<p align="center" width="100%">
    <img width="66%" src="https://raw.githubusercontent.com/linukc/master_dlcourse/main/images/logo.png">
</p>

 # **[MIPT DL frameworks Spring 2024](https://wiki.cogmodel.mipt.ru/s/mtai/doc/2024-nejrosetevye-frejmvorki-glubokogo-obucheniya-ZBGd69bxLd). Class 3: building deep learning blocks**

## ResNet

In [None]:
class ResidualBlock(nn.Module):
    """
    Two-convolution residual block: conv-BN-ReLU, conv-BN, skip addition, ReLU.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input feature map
    out_channels : int
        Number of channels produced by both 3x3 convolutions
    stride : int
        Stride of the first 3x3 convolution (the second always uses stride 1)
    downsample : nn.Module, optional
        Module applied to the input to build the skip branch. It has to be
        supplied whenever the main branch changes the tensor shape
        (``stride != 1`` or ``in_channels != out_channels``), otherwise the
        addition in ``forward`` fails. When ``None`` the input itself is added.
    """
    def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU())
        self.conv2 = nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
                        nn.BatchNorm2d(out_channels))
        # previously read from an undefined name, which made the constructor raise NameError
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        """Return ``relu(conv2(conv1(x)) + skip)`` where ``skip`` is ``x`` or ``downsample(x)``."""
        # project the input only when a projection module was provided
        residual = x if self.downsample is None else self.downsample(x)
        out = self.conv1(x)
        out = self.conv2(out)
        out = out + residual
        out = self.relu(out)
        return out

## Inception

In [None]:
class InceptionBlock(nn.Module):
    """
    Four parallel branches whose outputs are concatenated along the channel axis.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input
    out_1x1 : int
        Output channels of the 1x1 branch
    red_3x3, out_3x3 : int
        Channels after the 1x1 reduction / after the 3x3 convolution of branch 2
    red_5x5, out_5x5 : int
        Channels after the 1x1 reduction / after the 5x5 convolution of branch 3
    out_pool : int
        Output channels of the 1x1 convolution that follows the max-pooling

    Notes
    -----
    ``ConvBlock`` is defined outside this cell; it is presumably a
    convolution + normalisation + activation wrapper that forwards its keyword
    arguments to ``nn.Conv2d`` (TODO confirm against its definition).
    """
    def __init__(
        self,
        in_channels,
        out_1x1,
        red_3x3,
        out_3x3,
        red_5x5,
        out_5x5,
        out_pool,
    ):
        super().__init__()

        # branch 1: plain 1x1 convolution
        self.branch1 = ConvBlock(in_channels, out_1x1, kernel_size=1)

        # branch 2: 1x1 reduction, then 3x3 convolution
        reduce_3x3 = ConvBlock(in_channels, red_3x3, kernel_size=1, padding=0)
        conv_3x3 = ConvBlock(red_3x3, out_3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, conv_3x3)

        # branch 3: 1x1 reduction, then 5x5 convolution
        reduce_5x5 = ConvBlock(in_channels, red_5x5, kernel_size=1)
        conv_5x5 = ConvBlock(red_5x5, out_5x5, kernel_size=5, padding=2)
        self.branch3 = nn.Sequential(reduce_5x5, conv_5x5)

        # branch 4: 3x3 max-pooling (stride 1, padding 1), then 1x1 convolution
        pool = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)
        pool_proj = ConvBlock(in_channels, out_pool, kernel_size=1)
        self.branch4 = nn.Sequential(pool, pool_proj)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results on dim 1."""
        outputs = []
        for branch in (self.branch1, self.branch2, self.branch3, self.branch4):
            outputs.append(branch(x))
        return torch.cat(outputs, dim=1)

## Depthwise Separable Convolution

In [None]:
class SeparableConv2d(nn.Module):
    """
    Depthwise separable convolution: a per-channel (depthwise) convolution
    followed by a 1x1 (pointwise) convolution. Neither layer has a bias.

    Parameters
    ----------
    in_channels : int
        Number of input channels (also the number of depthwise groups)
    out_channels : int
        Number of channels produced by the pointwise convolution
    kernel_size, stride, padding, dilation
        Forwarded to the depthwise ``nn.Conv2d``
    """
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1):
        super().__init__()

        # depthwise: one filter per input channel (groups == in_channels)
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        # pointwise: 1x1 convolution that mixes the channels
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=False,
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise one."""
        return self.pointwise(self.conv1(x))

## MLP Mixer

In [None]:
from einops.layers.torch import Rearrange

In [None]:
class FeedForward(nn.Module):
    """
    Two-layer perceptron: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Parameters
    ----------
    dim : int
        Size of the last axis of the input and of the output
    hidden_dim : int
        Size of the hidden layer
    dropout : float
        Dropout probability used after the activation and after the output layer
    """
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),  # expand to the hidden size
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back to the input size
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last axis of ``x``."""
        out = self.net(x)
        return out

class MixerBlock(nn.Module):
    """
    MLP-Mixer block: a token-mixing MLP followed by a channel-mixing MLP,
    each wrapped in a residual connection.

    Parameters
    ----------
    dim : int
        Channel size (last axis) of the input, used by both LayerNorms
    num_patch : int
        Number of tokens; the token-mixing MLP works along this axis
    token_dim : int
        Hidden size of the token-mixing MLP
    channel_dim : int
        Hidden size of the channel-mixing MLP
    dropout : float
        Dropout probability passed to both MLPs
    """

    def __init__(self, dim, num_patch, token_dim, channel_dim, dropout = 0.):
        super().__init__()

        # token mixing: normalise, swap the token and channel axes so the MLP
        # acts across tokens, then swap back
        token_layers = [
            nn.LayerNorm(dim),
            Rearrange('b n d -> b d n'),
            FeedForward(num_patch, token_dim, dropout),
            Rearrange('b d n -> b n d'),
        ]
        self.token_mix = nn.Sequential(*token_layers)

        # channel mixing: normalise, then an MLP along the channel axis
        channel_layers = [
            nn.LayerNorm(dim),
            FeedForward(dim, channel_dim, dropout),
        ]
        self.channel_mix = nn.Sequential(*channel_layers)

    def forward(self, x):
        """Apply both mixing steps to ``x`` of layout ``(b, n, d)``, each with a skip connection."""
        mixed_tokens = x + self.token_mix(x)
        mixed_channels = mixed_tokens + self.channel_mix(mixed_tokens)
        return mixed_channels

## Patch Merger

In [None]:
class PatchMerger(nn.Module):
    """
    Merge ``n`` input tokens into ``num_tokens_out`` tokens using learned queries.

    Parameters
    ----------
    dim : int
        Feature size of every token
    num_tokens_out : int
        Number of tokens (learned queries) in the output
    """
    def __init__(self, dim, num_tokens_out):
        super().__init__()
        self.scale = dim ** -0.5
        self.norm = nn.LayerNorm(dim)
        # one learnable query vector per output token
        self.queries = nn.Parameter(torch.randn(num_tokens_out, dim))

    def forward(self, x):
        """Map ``x`` of shape (b, n, d) to a tensor of shape (b, num_tokens_out, d)."""
        tokens = self.norm(x)  # (b, n, d)
        # scaled similarity of every query with every token: (b, num_tokens_out, n)
        logits = (self.queries @ tokens.transpose(-1, -2)) * self.scale
        # normalise over the input tokens
        weights = torch.softmax(logits, dim=-1)
        # weighted average of the normalised tokens
        return weights @ tokens

## SE

In [None]:
import torch
from torch import nn

class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block: rescales every channel of the input by a
    gate in (0, 1) computed from the globally average-pooled input.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input (and of the output)
    reduction : int
        Reduction ratio of the bottleneck; the bottleneck has
        ``in_channels // reduction`` channels, but never fewer than one
    """
    def __init__(
        self,
        in_channels: int,
        reduction: int = 16
    ) -> None:
        super(SEBlock, self).__init__()

        # keep at least one bottleneck channel: for in_channels < reduction the
        # integer division gives 0, which would create an empty bottleneck
        out_channels = max(1, in_channels // reduction)

        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor (batch_size, in_channels, height, width)
            Input feature map

        Returns
        -------
        out : torch.Tensor (batch_size, in_channels, height, width)
            Input rescaled channel-wise by the excitation gates
        """
        z = self.squeeze(x)  # (batch_size, in_channels, 1, 1), eq.2
        s = self.excitation(z)  # (batch_size, in_channels, 1, 1), eq.3
        out = x * s  # channel-wise multiplication, eq. 4
        return out

## Selective Kernel

In [None]:
import torch
from torch import nn
from typing import List, Optional

class SKConv(nn.Module):
    """
    Selective Kernel convolution: several grouped convolutions with different
    kernel sizes are computed in parallel and fused with channel-wise softmax
    attention over the branches.

    Parameters
    ----------
    in_channels : int
        Number of input channels (must be divisible by ``groups``)
    out_channels : int, optional
        Number of output channels (must be divisible by ``groups``);
        defaults to ``in_channels``
    kernels : List[int], optional
        Kernel size of every branch, defaults to ``[3, 5]``. Sizes should be
        odd: the padding ``(k - 1) // 2`` keeps the spatial size only for odd
        ``k``, and all branches must produce the same shape to be summed.
    reduction : int
        Reduction ratio used for the size of the compact feature
    L : int
        Lower bound for the size of the compact feature
    groups : int
        Number of groups of every branch convolution
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        kernels: Optional[List[int]] = None,
        reduction: int = 16,
        L: int = 32,
        groups: int = 32
    ) -> None:
        super(SKConv, self).__init__()

        # a None sentinel instead of a shared mutable list as the default value
        if kernels is None:
            kernels = [3, 5]
        if len(kernels) == 0:
            raise ValueError("kernels must contain at least one kernel size")

        d = max(in_channels // reduction, L)  # eq.4

        self.M = len(kernels)

        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels

        # one conv-BN-ReLU branch per kernel size
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size = k,
                    stride = 1,
                    padding = (k - 1) // 2,
                    groups = groups
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )
            for k in kernels
        ])

        self.pool = nn.AdaptiveAvgPool2d(1)

        self.fc_z = nn.Sequential(
            nn.Linear(out_channels, d),
            nn.BatchNorm1d(d),
            nn.ReLU()
        )
        self.fc_attn = nn.Linear(d, out_channels * self.M)
        # softmax across the M branches
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor (batch_size, in_channels, height, width)
            Input feature map

        Returns
        -------
        out : torch.Tensor (batch_size, out_channels, height, width)
            Attention-weighted sum of the branch outputs
        """
        # ----- split -----
        feats = torch.cat([conv(x).unsqueeze(1) for conv in self.convs], dim=1)  # (batch_size, M, out_channels, height, width)

        # ----- fuse -----
        # eq.1
        U = torch.sum(feats, dim=1)  # (batch_size, out_channels, height, width)
        # channel-wise statistics, eq.2
        s = self.pool(U).squeeze(-1).squeeze(-1)  # (batch_size, out_channels)
        # compact feature, eq.3
        z = self.fc_z(s)  # (batch_size, d)

        # ----- select -----
        batch_size, out_channels = s.shape

        # attention map, eq.5
        score = self.fc_attn(z)  # (batch_size, M * out_channels)
        score = score.view(batch_size, self.M, out_channels, 1, 1)  # (batch_size, M, out_channels, 1, 1)
        att = self.softmax(score)

        # fuse multiple branches, eq.6
        out = torch.sum(att * feats, dim=1)  # (batch_size, out_channels, height, width)
        return out

## RoPE

https://nn.labml.ai/transformers/rope/index.html

In [None]:
class RotaryPositionalEmbeddings(nn.Module):
    r"""
    Rotary positional embeddings (RoPE).

    The first `d` features of the input are rotated pairwise, feature $i$
    together with feature $i + \frac{d}{2}$, by an angle $m \theta_i$ that
    depends on the position $m$ along the first axis. Features beyond the
    first `d` are returned unchanged.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$ that get rotated, it has to be even
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        # The features are rotated in pairs $(i, i + \frac{d}{2})$
        if d % 2 != 0:
            raise ValueError(f"d must be even, got {d}")

        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values for all positions of `x`
        """
        # Get sequence length
        seq_len = x.shape[0]

        # Return if the cache is long enough and lives on the same device as `x`
        if (
            self.cos_cached is not None
            and seq_len <= self.cos_cached.shape[0]
            and self.cos_cached.device == x.device
        ):
            return

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2, device=x.device).float() / self.d))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float()

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them with singleton axes for batch and heads: `[seq_len, 1, 1, d]`
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        r"""
        Swap the two halves of the last axis of `x` and negate the half that moves to the front
        """
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        r"""
        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`

        Returns a tensor of the same shape as `x`
        """
        # Cache $\cos$ and $\sin$ values
        self._build_cache(x)
        seq_len = x.shape[0]

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        # Calculate
        #
        # \begin{align}
        # \begin{pmatrix}
        # x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
        # x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i \\
        # \end{pmatrix} \\
        # \end{align}
        #
        # for $i \in {1, 2, ..., \frac{d}{2}}$
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])

        # Put the rotated and the untouched features back together
        return torch.cat((x_rope, x_pass), dim=-1)

## Vanilla Attention

In [None]:
from typing import Optional
import torch
from torch import nn
import numpy as np

In [None]:
class VanillaAttention(nn.Module):
    """
    Attention network following [1] and [2]: a single query vector attends
    over a sequence of keys, which also serve as the values.

    Parameters
    ----------
    dim : int
        Size of the input tensor

    References
    ----------
    1. "`Neural Machine Translation by Jointly Learning to Align and Translate.
       <https://arxiv.org/abs/1409.0473>`_" Dzmitry Bahdanau, et al. ICLR 2015.
    2. "`Effective Approaches to Attention-based Neural Machine Translation.
       <https://arxiv.org/abs/1508.04025>`_" Minh-Thang Luong, et al. EMNLP 2015.
    """
    def __init__(
        self,
        dim: int,
    ) -> None:
        super().__init__()

        # projection of the query used for the alignment scores
        self.fc_align = nn.Linear(dim, dim)

        # projections combined into the final output
        self.fc_query = nn.Linear(dim, dim)
        self.fc_value = nn.Linear(dim, dim)

        # normalises the scores over the length axis
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()


    def forward(
        self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        query : torch.Tensor (batch_size, dim)
            Query
        key : torch.Tensor (batch_size, length, dim)
            Key, also used as the value that is averaged

        Returns
        -------
        out : torch.Tensor (batch_size, dim)
            Output tensor
        att : torch.Tensor (batch_size, length)
            Attention weights

        Notes
        -----
        Both values are returned together as the tuple ``(out, att)``.
        """
        # alignment scores: dot product of every key with the projected query
        projected_query = self.fc_align(query)  # (batch_size, dim)
        score = torch.matmul(key, projected_query.unsqueeze(2))  # (batch_size, length, 1)
        score = score.squeeze(2)  # (batch_size, length)

        # attention weights
        att = self.softmax(score)  # (batch_size, length)

        # context vector: attention-weighted sum of the keys
        context = torch.matmul(att.unsqueeze(1), key)  # (batch_size, 1, dim)
        context = context.squeeze(1)  # (batch_size, dim)

        # attention result
        combined = self.fc_value(context) + self.fc_query(query)
        out = self.tanh(combined)

        return out, att

## LSTM

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

class LSTMCell(nn.Module):
    """
    Single LSTM cell built from two linear layers that compute all four gates at once.

    Parameters
    ----------
    input_size : int
        Number of features of the input
    hidden_size : int
        Number of features of the hidden and of the cell state
    bias : bool
        Whether both linear layers use a bias
    """
    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # each layer produces the four gate pre-activations side by side
        self.xh = nn.Linear(input_size, hidden_size * 4, bias=bias)
        self.hh = nn.Linear(hidden_size, hidden_size * 4, bias=bias)

    def forward(self, input, hx=None):
        """
        Parameters
        ----------
        input : torch.Tensor (batch_size, input_size)
            Input of the current time step
        hx : tuple of two torch.Tensor (batch_size, hidden_size), optional
            Previous hidden state and previous cell state. When ``None`` both
            start as zeros.

        Returns
        -------
        hy : torch.Tensor (batch_size, hidden_size)
            New hidden state
        cy : torch.Tensor (batch_size, hidden_size)
            New cell state
        """
        if hx is None:
            # no previous state given: start from zeros on the input's device and dtype
            zeros = input.new_zeros(input.size(0), self.hidden_size)
            hx = (zeros, zeros)

        hx, cx = hx

        gates = self.xh(input) + self.hh(hx)

        # Get gates (i_t, f_t, g_t, o_t)
        input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)

        i_t = torch.sigmoid(input_gate)
        f_t = torch.sigmoid(forget_gate)
        g_t = torch.tanh(cell_gate)
        o_t = torch.sigmoid(output_gate)

        # forget part of the old cell state, add the gated candidate
        cy = cx * f_t + i_t * g_t

        hy = o_t * torch.tanh(cy)

        return (hy, cy)

## Attention

In [None]:
from typing import Tuple, Optional
import torch
from torch import nn
import numpy as np

In [None]:
def split_heads(x: torch.Tensor, n_heads: int) -> torch.Tensor:
    """
    Split the last axis of ``x`` into ``n_heads`` heads and move the head axis to position 1.

    Parameters
    ----------
    x : torch.Tensor (batch_size, length, dim)
        Tensor whose last axis is split; ``dim`` should be divisible by ``n_heads``
    n_heads : int
        Number of heads

    Returns
    -------
    torch.Tensor (batch_size, n_heads, length, dim // n_heads)
    """
    batch_size = x.size(0)
    d_head = x.size(-1) // n_heads
    per_head = x.view(batch_size, -1, n_heads, d_head)  # (batch_size, length, n_heads, d_head)
    return per_head.transpose(1, 2)  # (batch_size, n_heads, length, d_head)

def combine_heads(x: torch.Tensor) -> torch.Tensor:
    """
    Merge the head axis back into the feature axis.

    Parameters
    ----------
    x : torch.Tensor (batch_size, n_heads, length, d_head)
        Per-head tensor

    Returns
    -------
    torch.Tensor (batch_size, length, n_heads * d_head)
    """
    batch_size = x.size(0)
    merged_dim = x.size(3) * x.size(1)
    # put the heads next to their features; the copy makes the memory contiguous for view()
    heads_last = x.transpose(1, 2).contiguous()  # (batch_size, length, n_heads, d_head)
    return heads_last.view(batch_size, -1, merged_dim)  # (batch_size, length, n_heads * d_head)

def add_mask(x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Fill the masked positions of ``x`` with ``-inf``.

    Parameters
    ----------
    x : torch.Tensor (batch_size, ..., length)
        Scores to mask, e.g. (batch_size, n_heads, length, length)
    mask : torch.Tensor (batch_size, length), optional
        Positions whose entry is non-zero / ``True`` are filled with ``-inf``
        along the last axis of ``x``. When ``None``, ``x`` is returned as is.

    Returns
    -------
    torch.Tensor
        Tensor of the same shape as ``x``; ``x`` itself is not modified
    """
    if mask is None:
        return x

    # insert singleton axes after the batch axis until the ranks match, so the
    # mask broadcasts over all middle axes; for 4-D scores this gives
    # (batch_size, 1, 1, length)
    expanded_mask = mask
    for _ in range(x.dim() - mask.dim()):
        expanded_mask = expanded_mask.unsqueeze(1)

    return x.masked_fill(expanded_mask.bool(), float("-inf"))

In [None]:
class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention: ``softmax(Q·K^T / scale)·V``.

    Parameters
    ----------
    scale : float
        Value the queries are divided by before the dot product
        (``sqrt(d_head)`` in the usual setting)
    dropout : float, optional
        Dropout probability applied to the attention weights; ``None``
        disables dropout
    """
    def __init__(self, scale: float, dropout: float = 0.5) -> None:
        super(ScaledDotProductAttention, self).__init__()

        self.scale = scale
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = None if dropout is None else nn.Dropout(dropout)

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parameters
        ----------
        Q, K, V : torch.Tensor (batch_size, n_heads, length, d_head)
            Queries, keys and values
        mask : torch.Tensor, optional
            Passed to ``add_mask`` together with the raw scores

        Returns
        -------
        context : torch.Tensor (batch_size, n_heads, length, d_head)
            Attention-weighted values
        att : torch.Tensor (batch_size, n_heads, length, length)
            Attention weights (after dropout, if enabled)
        """
        # Q·K^T / sqrt(d_head); the last two axes of K are swapped
        score = torch.matmul(Q / self.scale, K.transpose(-2, -1))  # (batch_size, n_heads, length, length)
        score = add_mask(score, mask)

        # eq.1: Attention(Q, K, V) = softmax(Q·K^T / sqrt(d_head))·V
        att = self.softmax(score)  # (batch_size, n_heads, length, length)
        att = att if self.dropout is None else self.dropout(att)
        context = att @ V  # (batch_size, n_heads, length, d_head)

        return context, att

In [None]:
class SelfAttention(nn.Module):
    """
    Multi-head self-attention followed by an output projection, a residual
    connection and LayerNorm.

    Parameters
    ----------
    dim : int
        Size of the last axis of the input; must be divisible by ``n_heads``
    n_heads : int
        Number of attention heads
    dropout : float, optional
        Dropout probability used on the attention weights and on the projected
        output; ``None`` disables dropout
    """
    def __init__(
        self,
        dim: int,
        n_heads: int = 8,
        dropout: Optional[float] = None
    ) -> None:
        super(SelfAttention, self).__init__()

        assert dim % n_heads == 0, f"dim ({dim}) must be divisible by n_heads ({n_heads})"

        self.n_heads = n_heads
        self.d_head = dim // n_heads

        # linear projections
        self.W_Q = nn.Linear(dim, n_heads * self.d_head)
        self.W_K = nn.Linear(dim, n_heads * self.d_head)
        self.W_V = nn.Linear(dim, n_heads * self.d_head)

        # scaled dot-product attention
        scale = self.d_head ** 0.5  # scale factor
        self.attention = ScaledDotProductAttention(scale=scale, dropout=dropout)

        self.layer_norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(n_heads * self.d_head, dim)

        self.dropout = None if dropout is None else nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor (batch_size, length, dim)
            Input sequence, used as query, key and value
        mask : torch.Tensor, optional
            Passed on to the scaled dot-product attention

        Returns
        -------
        out : torch.Tensor (batch_size, length, dim)
            Normalised sum of the input and the attention output
        """
        Q = self.W_Q(x)  # (batch_size, length, n_heads * d_head)
        K = self.W_K(x)
        V = self.W_V(x)

        Q, K, V = split_heads(Q, self.n_heads), split_heads(K, self.n_heads), split_heads(V, self.n_heads)
        # (batch_size, n_heads, length, d_head)

        # the attention weights are not needed here
        context, _ = self.attention(Q, K, V, mask=mask)  # (batch_size, n_heads, length, d_head)
        context = combine_heads(context)  # (batch_size, length, n_heads * d_head)

        out = self.fc(context)  # (batch_size, length, dim)
        out = out if self.dropout is None else self.dropout(out)

        out = out + x  # residual connection
        out = self.layer_norm(out)  # LayerNorm

        return out