In [4]:
import math
import torch # type: ignore
import numpy as np # type: ignore
import matplotlib.pyplot as plt # type: ignoreb


nn = torch.nn
F = torch.nn.functional

### Dot-product attention (a.k.a. softmax attention); FlashAttention is an exact, IO-aware CUDA implementation of it

Flash Attention:
- https://arxiv.org/abs/2407.08608 (FlashAttention-3)
- https://github.com/dao-ailab/flash-attention
- https://www.stephendiehl.com/posts/flash_attention/
- DeepSeek Multi-head Latent Attention (MLA) — related to, but distinct from, grouped-query attention

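$\mathrm{Attention}(X) = \mathrm{softmax}\left(\frac{Q \cdot K^{\top}}{\sqrt{d_{k}}} + M\right) \cdot V$, where $M$ is an additive mask (0 for allowed positions, $-\infty$ for masked ones).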


$d_{k}$ = `Q.size(-1)` is the per-head dimension, and number of heads $\times$ head dimension = embedding dimension.

e.g., with embedding dimension 64 we could use 2 heads $\times$ 32 head dimension, or 8 heads $\times$ 8 head dimension (any divisor of 64 works as the number of heads).

$Q = X \cdot W_{q}$

$K = X \cdot W_{k}$

$V = X \cdot W_{v}$

In [None]:
def dot_prod_attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention'.

    Returns softmax(query @ key^T / sqrt(d_k) + mask) @ value, where
    d_k = query.size(-1). Leading dimensions broadcast as in torch.matmul.

    Args:
        query, key, value: tensors whose last two dims are (seq, dim).
        mask: optional.
            * Floating-point masks are ADDED to the scores (0 keeps a position,
              -inf removes it), e.g. a causal 3x3 mask:
                  [[0, -inf, -inf],
                   [0,    0, -inf],
                   [0,    0,    0]]
            * Boolean masks follow torch.nn.functional.scaled_dot_product_attention:
              True = may attend, False = masked out.
        dropout: optional callable (e.g. nn.Dropout) applied to the attention weights.

    Returns:
        (output, p_attn): the attention output and the attention weights
        (after dropout, if any).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        if isinstance(mask, torch.Tensor) and mask.dtype == torch.bool:
            # Adding a bool tensor would silently add 1.0 to the kept positions;
            # mask out the False entries explicitly instead.
            scores = scores.masked_fill(~mask, float("-inf"))
        else:
            scores = scores + mask
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


In [None]:

def linformer_attention(Q, K, V, E, mask):
    """Linformer-style attention: compress keys/values along the sequence axis with E.

    K and V are multiplied element-wise by ``mask[:, None, :, None]`` (so 0
    entries contribute nothing) and then left-multiplied by E before the usual
    scaled softmax attention.

    Args:
        Q, K, V: presumably (batch, heads, seq, head_dim) — TODO confirm.
        E: projection applied on the left of K and V; presumably (heads, k, seq).
        mask: 2-D (batch, seq) mask (its indexing fixes the rank).

    Returns:
        softmax(Q (E K)^T / sqrt(head_dim)) (E V), with head_dim = Q.size(-1).
    """
    keep = mask[:, None, :, None]
    projected_k, projected_v = (torch.matmul(E, t * keep) for t in (K, V))
    logits = torch.matmul(Q, projected_k.transpose(-2, -1)) / math.sqrt(Q.size(-1))
    return torch.matmul(F.softmax(logits, dim=-1), projected_v)

In [None]:
def aft_attention(self, x, d_model,  n=49, simple=False):
    """Attention Free Transformer (AFT-full), functional scratch version.

    out[b, t] = sigmoid(q[b, t]) * sum_t' softmax_t'(k[b, t'] + w[t, t']) * v[b, t']

    NOTE(review): the q/k/v nn.Linear layers are freshly (randomly) initialised
    on every call, so results depend on the RNG state and the projections are
    not trainable. `self` is only used to store the position biases `w` as
    `self.position_biases`.

    Args:
        self: any object; `self.position_biases` is (re)assigned on each call.
        x: (B, N, D) input, with D == d_model and N <= n.
        d_model: feature size of the projections.
        n: size of the (n, n) position-bias matrix (maximum sequence length).
        simple: if True the biases are all zeros, otherwise an all-ones nn.Parameter.

    Returns:
        (B, N, D) tensor.

    Raises:
        ValueError: if N > n.
    """
    B, N, D = x.shape
    if N > n:
        raise ValueError(f"sequence length {N} exceeds the position-bias size n={n}")
    if simple:
        self.position_biases = torch.zeros((n, n))
    else:
        self.position_biases = nn.Parameter(torch.ones((n, n)))
    q = nn.Linear(d_model, d_model)(x)
    k = nn.Linear(d_model, d_model)(x).view(1, B, N, D)
    v = nn.Linear(d_model, d_model)(x).view(1, B, N, D)
    # (N, 1, N, 1): w[t, t'] for the first N positions, broadcast over batch and features.
    w = self.position_biases[:N, :N][:, None, :, None]
    # A softmax over t' equals exp(.) * v / sum exp(.) but cannot overflow for large logits.
    weights = (k + w).softmax(dim=2)          # (N, B, N, D)
    out = torch.sum(weights * v, dim=2)       # (N, B, D)
    out = torch.sigmoid(q) * (out.permute(1, 0, 2))
    return out

In [None]:
import math
import torch  # type: ignore
from src.components.utils import clones
# from transformers.modeling_reformer import LSHSelfAttention, ReformerConfig


nn = torch.nn
F = torch.nn.functional


class SoftmaxAttention(nn.Module):
    """Scaled dot-product attention with a per-key mask and dropout on the weights.

    Config keys read: "dropout_prob" (dropout on the attention weights) and
    "head_dim" (scores are divided by sqrt(head_dim)).
    """

    def __init__(self, config):
        super().__init__()
        self.drop_attn = torch.nn.Dropout(p=config["dropout_prob"])
        self.head_dim = config["head_dim"]

    def forward(self, Q, K, V, mask):
        """Return dropout(softmax(Q K^T / sqrt(head_dim) - penalty)) V.

        `mask` is indexed as mask[:, None, None, :], i.e. one value per
        (batch, key); presumably 1 = keep and 0 = padding (TODO confirm).
        Positions where mask is 0 get a finite -1e6 penalty rather than -inf,
        so a fully masked row does not produce NaN.
        """
        scale = math.sqrt(self.head_dim)
        key_mask = mask[:, None, None, :]
        logits = torch.matmul(Q, torch.transpose(K, -2, -1)) / scale - 1e6 * (1 - key_mask)
        weights = self.drop_attn(F.softmax(logits, dim=-1))
        return torch.matmul(weights, V)


class MultiHeadedAttention(nn.Module):
    """Multi-head attention over `h` heads of size d_model // h (Annotated-Transformer style).

    Uses four d_model -> d_model linear layers: the query, key and value
    projections (linears[0], [1], [2]) and the output projection (linears[3]).
    The attention weights of the most recent forward pass are kept in `self.attn`.
    """

    def __init__(
        self,
        h,
        d_model,
        dropout=0.1,
        bias=True,
        freeze_q=False,
        freeze_k=False,
        freeze_v=False,
        zero_k=False,
    ):
        """Take in model size and number of heads.

        Args:
            h: number of heads; must divide d_model.
            d_model: model (embedding) dimension.
            dropout: dropout probability applied to the attention weights.
            bias: whether the linear layers use a bias term.
            freeze_q, freeze_k, freeze_v: turn off gradients for the query /
                key / value projection respectively.
            zero_k: set the key projection to all zeros and freeze it, so every
                projected key, and therefore every score before masking, is 0.
        """
        super(MultiHeadedAttention, self).__init__()
        self.bias = bias
        assert d_model % h == 0, "d_model must be divisible by the number of heads h"
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model, bias=bias), 4)
        if freeze_q:
            self.linears[0].requires_grad_(False)
        if freeze_k:
            self.linears[1].requires_grad_(False)
        if freeze_v:
            self.linears[2].requires_grad_(False)
        if zero_k:
            self.null_linear_layer(self.linears[1])
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def null_linear_layer(self, ln):
        """Zero `ln`'s weight (and bias, when the layers have one) and freeze it."""
        with torch.no_grad():
            ln.weight.fill_(0.0)
            if self.bias:
                ln.bias.fill_(0.0)
        ln.requires_grad_(False)

    def forward(self, query, key, value, mask=None):
        """Implements Figure 2 of "Attention Is All You Need".

        Args:
            query, key, value: (batch, seq, d_model) tensors (the `view` below
                splits the last dim into h x d_k).
            mask: optional mask forwarded to `dot_prod_attention`, which adds it
                to the scores (e.g. 0 = keep, -inf = masked). Accepted layouts:
                (seq_q, seq_k) -> shared by every batch element and head;
                (batch, seq_q or 1, seq_k) -> shared by every head;
                any other rank -> used as given (must broadcast against
                (batch, h, seq_q, seq_k)).

        Returns:
            (batch, seq_q, d_model) tensor.
        """
        if mask is not None:
            if mask.dim() == 2:
                # Same mask applied to all batch elements and all h heads.
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.dim() == 3:
                # Per-batch mask: insert the head axis so it is shared by all h heads.
                mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            layer(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for layer, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = dot_prod_attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)


class LinformerAttention(nn.Module):
    """Linformer attention: keys and values are projected along the sequence axis.

    The projection matrix E, of shape (num_head, linformer_k, max_seq_len), is
    cached on the class. The first instance allocates it and every later
    instance reuses the very same Parameter, so all instances share E.
    NOTE(review): a later instance whose config has a different shape still
    receives the cached matrix.

    Config keys read: "num_head", "head_dim", "linformer_k", "max_seq_len".
    """

    # Class-level cache for the shared projection Parameter.
    projection_matrix = None

    def __init__(self, config):
        super().__init__()

        self.num_head = config["num_head"]
        self.head_dim = config["head_dim"]
        self.linformer_k = config["linformer_k"]
        self.seq_len = config["max_seq_len"]

        shared = LinformerAttention.projection_matrix
        if shared is None:
            shared = nn.Parameter(torch.Tensor(self.num_head, self.linformer_k, self.seq_len))
            nn.init.normal_(shared, std=0.02)
            LinformerAttention.projection_matrix = shared
        self.E = shared

    def forward(self, Q, K, V, mask):
        """Return softmax(Q (E K)^T / sqrt(head_dim)) (E V).

        K and V are multiplied by mask[:, None, :, None] before projection, so
        `mask` is (batch, seq); presumably 1 = keep, 0 = padding (TODO confirm).
        """
        keep = mask[:, None, :, None]
        K_proj = torch.matmul(self.E, K * keep)
        V_proj = torch.matmul(self.E, V * keep)
        scores = torch.matmul(Q, K_proj.transpose(-2, -1)) / math.sqrt(self.head_dim)
        return torch.matmul(F.softmax(scores, dim=-1), V_proj)


class NystromAttention(nn.Module):
    """Nyströmformer attention: softmax attention approximated through landmarks.

    Landmark queries/keys are means over num_landmarks consecutive segments of
    length seq_len // num_landmarks. The attention matrix is approximated as
    kernel_1 @ iterative_inv(kernel_2) @ kernel_3. When num_landmarks equals
    seq_len, exact softmax attention is computed instead.

    Config keys read: "head_dim", "num_head", "num_landmarks", "seq_len", and
    optionally "conv_kernel_size" (adds a per-head convolution of V along the
    sequence axis) and "inv_coeff_init_option" (initialisation used by
    iterative_inv; "original" unless another value is given).
    """

    def __init__(self, config):
        super().__init__()

        self.head_dim = config["head_dim"]
        self.num_head = config["num_head"]

        self.num_landmarks = config["num_landmarks"]
        self.seq_len = config["seq_len"]
        # The landmark path reshapes seq_len into num_landmarks equal segments.
        if self.num_landmarks != self.seq_len and (
            self.num_landmarks <= 0 or self.seq_len % self.num_landmarks != 0
        ):
            raise ValueError(
                f"seq_len ({self.seq_len}) must be divisible by num_landmarks "
                f"({self.num_landmarks}), which must be positive"
            )

        # Accept both spellings of the option key; the presence check and the
        # lookup must refer to the same key.
        if "inv_coeff_init_option" in config:
            self.init_option = config["inv_coeff_init_option"]
        elif "inv_init_coeff_option" in config:
            self.init_option = config["inv_init_coeff_option"]
        else:
            self.init_option = "original"

        self.use_conv = "conv_kernel_size" in config
        if self.use_conv:
            # One (kernel_size x 1) filter per head (groups=num_head), sliding along the sequence axis.
            self.conv = nn.Conv2d(
                in_channels=self.num_head,
                out_channels=self.num_head,
                kernel_size=(config["conv_kernel_size"], 1),
                padding=(config["conv_kernel_size"] // 2, 0),
                bias=False,
                groups=self.num_head
            )

    def forward(self, Q, K, V, mask):
        """Return the (approximate) softmax attention output.

        The reshapes below treat Q and K as (batch, num_head, seq_len, head_dim).
        `mask` is indexed as (batch, seq_len); presumably 1 = keep and
        0 = padding (TODO confirm).
        """
        # Scale Q and K by head_dim ** -0.25 each so Q K^T carries 1 / sqrt(head_dim).
        Q = Q * mask[:, None, :, None] / math.sqrt(math.sqrt(self.head_dim))
        K = K * mask[:, None, :, None] / math.sqrt(math.sqrt(self.head_dim))

        if self.num_landmarks == self.seq_len:
            # Exact softmax attention; masked keys get a -1e9 penalty.
            attn = F.softmax(torch.matmul(Q, K.transpose(-1, -2)) - 1e9 * (1 - mask[:, None, None, :]), dim=-1)
            X = torch.matmul(attn, V)
        else:
            # Landmarks: the mean of each consecutive segment of seq_len // num_landmarks positions.
            Q_landmarks = Q.reshape(-1, self.num_head, self.num_landmarks, self.seq_len // self.num_landmarks, self.head_dim).mean(dim=-2)
            K_landmarks = K.reshape(-1, self.num_head, self.num_landmarks, self.seq_len // self.num_landmarks, self.head_dim).mean(dim=-2)

            kernel_1 = F.softmax(torch.matmul(Q, K_landmarks.transpose(-1, -2)), dim=-1)
            kernel_2 = F.softmax(torch.matmul(Q_landmarks, K_landmarks.transpose(-1, -2)), dim=-1)
            kernel_3 = F.softmax(torch.matmul(Q_landmarks, K.transpose(-1, -2)) - 1e9 * (1 - mask[:, None, None, :]), dim=-1)
            X = torch.matmul(torch.matmul(kernel_1, self.iterative_inv(kernel_2)), torch.matmul(kernel_3, V))

        if self.use_conv:
            X += self.conv(V * mask[:, None, :, None])

        return X

    def iterative_inv(self, mat, n_iter=6):
        """Approximate the inverse of `mat` with `n_iter` steps of
        Z <- 1/4 Z (13I - KZ (15I - KZ (7I - KZ))), the iteration used by Nyströmformer.

        The initial Z_0 is K^T scaled by 1 / (maximum column sum of K):
        over the whole tensor for init_option "original", or per
        (batch, head) otherwise (which requires a 4-D input).
        """
        I = torch.eye(mat.size(-1), device=mat.device)
        K = mat

        # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
        if self.init_option == "original":
            # This original implementation is more conservative to compute coefficient of Z_0.
            V = 1 / torch.max(torch.sum(K, dim=-2)) * K.transpose(-1, -2)
        else:
            # This is the exact coefficient computation, 1 / ||K||_1, of initialization of Z_0, leading to faster convergence.
            V = 1 / torch.max(torch.sum(K, dim=-2), dim=-1).values[:, :, None, None] * K.transpose(-1, -2)

        for _ in range(n_iter):
            KV = torch.matmul(K, V)
            V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV, 15 * I - torch.matmul(KV, 7 * I - KV)))
        return V

    def extra_repr(self):
        return f'num_landmarks={self.num_landmarks}, seq_len={self.seq_len}'


# class LSHAttention(LSHSelfAttention):
#     """
#     LSH Attention - Reformer Attention
#     """
#     def __init__(self, config, query, key, value):
#         reformer_config = ReformerConfig()
#         reformer_config.attn_layers = ["lsh"]
#         reformer_config.num_hashes = config["num_hash"]
#         reformer_config.is_decoder = False
#         reformer_config.max_position_embeddings = config["max_seq_len"]
#         reformer_config.hidden_size = config["transformer_dim"]
#         super().__init__(reformer_config)
#         self.query_key.weight = query.weight
#         self.value.weight = value.weight

#     def forward(self, X, mask):
#         return super().forward(hidden_states=X, attention_mask=mask).hidden_states


class AFTAttention(nn.Module):  # AFT Attention
    """
    AFT Attention
    refer to "An Attention Free Transformer"

    AFT-full: out[b, t] = sigmoid(Q[b, t]) * sum_t' softmax_t'(K[b, t'] + w[t, t']) * V[b, t'],
    where w is an (n, n) matrix of pairwise position biases: a learnable
    Parameter initialised to ones, or fixed zeros when `simple=True`.
    """
    def __init__(self, d_model, n=49, simple=False):
        super().__init__()
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model, d_model)
        if simple:
            # A buffer (not a plain attribute) so .to(device)/.cuda() moves it with
            # the module; non-persistent so state_dict keys stay unchanged.
            self.register_buffer("position_biases", torch.zeros((n, n)), persistent=False)
        else:
            self.position_biases = nn.Parameter(torch.ones((n, n)))
        self.d_model = d_model
        self.n = n
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Args:
            x: (B, N, D) input with D == d_model and N <= n.

        Returns:
            (B, N, D) tensor.

        Raises:
            ValueError: if N > n.
        """
        B, N, D = x.shape
        if N > self.n:
            raise ValueError(f"sequence length {N} exceeds the position-bias size n={self.n}")

        q = self.fc_q(x)
        k = self.fc_k(x).view(1, B, N, D)
        v = self.fc_v(x).view(1, B, N, D)

        # (N, 1, N, 1): w[t, t'] for the first N positions, broadcast over batch and features.
        w = self.position_biases[:N, :N][:, None, :, None]
        # A softmax over t' equals exp(.) * v / sum exp(.) but cannot overflow for large logits.
        weights = (k + w).softmax(dim=2)        # (N, B, N, D)
        out = torch.sum(weights * v, dim=2)     # (N, B, D)

        out = self.sigmoid(q) * (out.permute(1, 0, 2))
        return out


# if __name__ == "__main__":
#     dummy_input = ms.ops.randn((50, 49, 512))
#     aft_full = AFT_FULL(d_model=512, n=49)
#     output = aft_full(dummy_input)
#     print(output.shape)


# Relative Position Attention
"""
An illustrated explanation and code can be found at:
https://nn.labml.ai/transformers/xl/relative_mha.html
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/xl/relative_mha.py
https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py
https://medium.com/@rajveer.rathod1301/relative-positional-multi-head-attention-an-overview-e7d22a63e01c
"""
# class RelativePosition(nn.Module):

#     def __init__(self, num_units, max_relative_position):
#         super().__init__()
#         self.num_units = num_units
#         self.max_relative_position = max_relative_position
#         self.embeddings_table = Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units)
#         nn.init.xavier_uniform_(self.embeddings_table)

#     def forward(self, length_q, length_k):
#         range_vec_q = torch.arange(length_q)
#         range_vec_k = torch.arange(length_k)
#         distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
#         distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
#         final_mat = distance_mat_clipped + self.max_relative_position
#         final_mat = torch.LongTensor(final_mat).cuda()
#         embeddings = self.embeddings_table[final_mat].cuda()

#         return embeddings


# self.relative_position_k = RelativePosition(i, self.d_k, max_relative_position)
# self.relative_position_v = RelativePosition(i, self.d_v, max_relative_position)

# r_q = q.permute(2, 0, 1, 3).contiguous().view(len_q, sz_b*n_head, d_k)
# r_k = self.relative_position_k(len_q, len_k)
# attn_2 = torch.matmul(r_q, r_k.transpose(1, 2)).transpose(0, 1)
# attn_2 = attn_2.contiguous().view(sz_b, self.n_head, len_k, len_k)

# r_v = self.relative_position_v(len_q, len_v)
# weight = attn.permute(2, 0, 1, 3).contiguous().view(len_q, sz_b*n_head, len_k)
# weight = torch.matmul(weight, r_v)
# weight = weight.transpose(0, 1).contiguous().view(sz_b, self.n_head, len_q, d_v)


"""
Switch Transformers like MoE, Routing Transformer in the FFN layer
https://github.com/kyegomez/SwitchTransformers/blob/main/switch_transformers/model.py
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/switch/__init__.py
"""

"""
To read https://pytorch.org/blog/flexattention/
"""


# type = {}

# with open("/model/config.json", "r") as f:
#     config = json.load(f)
# model_config = config["model"]

# if model_config["attn_type"] == "softmax":
#     type["softmax"] = SoftmaxAttention

# elif model_config["attn_type"] == "nystrom":
#     from attention_nystrom import NystromAttention
#     type["nystrom"] = NystromAttention

# elif model_config["attn_type"] == "reformer":
#     from attention_reformer import LSHAttention
#     type["reformer"] = LSHAttention

# elif model_config["attn_type"] == "linformer":
#     from attention_linformer import LinformerAttention
#     type["linformer"] = LinformerAttention

# else:
#     raise Exception()


# Config for NystromModel 512
# type = {
#     "model_checkpoints": "/model/model",
#     "data_folder": "/dataset",
#     "glue_dataset_folder": "/glue",
#     "wikihop_dataset_folder": "/wikihop",
#     "model": {
#         "mixed_precision": true,
#         "attention_grad_checkpointing": false,
#         "gelu_grad_checkpointing": true,
#         "vocab_size": 50265,
#         "num_sen_type": 1,
#         "max_seq_len": 512,
#         "embedding_dim": 768,
#         "transformer_dim": 768,
#         "transformer_hidden_dim": 3072,
#         "num_layers": 12,
#         "dropout_prob": 0.1,
#         "num_head": 12,
#         "head_dim": 64,
#         "attn_type": "nystrom",
#         "num_landmarks": 64,
#         "seq_len": 512,
#         "conv_kernel_size": 33
#     },
#     "pretraining_setting": {
#         "batch_size": 256,
#         "learning_rate": 0.0001,
#         "warmup": 0.01,
#         "batches_per_report": 10,
#         "batches_per_epoch": 5000,
#         "epoch": 100,
#         "validate_batches_per_epoch": 100
#     },
#     "gpu_setting": {
#         "inst_per_gpu": 8
#     }
# }

import math

import torch
from torch import nn
from torch.nn import functional as F

from alibi.config import ALiBiConfig


def get_relative_positions(seq_len: int, device=None) -> torch.Tensor:
    """Return the (seq_len, seq_len) matrix of signed offsets key_index - query_index.

    Entry [i, j] is j - i: 0 on the diagonal, negative below it (earlier keys)
    and positive above it (later keys).

    Args:
        seq_len: sequence length.
        device: optional device for the result; defaults to torch's default device.

    Returns:
        int64 tensor of shape (seq_len, seq_len).
    """
    positions = torch.arange(seq_len, device=device)
    return positions[None, :] - positions[:, None]


def get_alibi_slope(num_heads):
    """Return the ALiBi per-head slopes as a (num_heads, 1, 1) float tensor.

    Head h (0-based) gets slope 1 / r ** (h + 1) with r = 256 ** (1 / num_heads),
    i.e. the geometric sequence 2^(-8/num_heads), 2^(-16/num_heads), ..., 2^-8.
    """
    ratio = (2 ** 8) ** (1 / num_heads)
    slopes = [1 / ratio ** power for power in range(1, num_heads + 1)]
    return torch.tensor(slopes)[:, None, None]


class ALiBiMultiHeadAttention(nn.Module):
    """
    from https://github.com/jaketae/alibi/blob/main/alibi/attention.py

    Multi-head self-attention with ALiBi linear biases: no positional
    embeddings; instead the score between query i and key j is penalised by
    m_h * |i - j| with a per-head slope m_h (see get_alibi_slope).

    Config fields read: causal, num_heads, d_model, dropout, and max_len
    (only when causal).
    """
    def __init__(self, config: ALiBiConfig) -> None:
        super().__init__()
        self.causal = config.causal
        self.num_heads = config.num_heads
        # NOTE(review): scores are divided by sqrt(d_model), not sqrt(d_head) as in
        # standard scaled dot-product attention; kept so existing checkpoints behave the same.
        self.scale = math.sqrt(config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        # Per-head slopes m, shape (num_heads, 1, 1).
        self.register_buffer("m", get_alibi_slope(self.num_heads))
        self.kqv = nn.Linear(config.d_model, 3 * config.d_model, bias=False)
        if config.causal:
            # Lower-triangular (1, 1, max_len, max_len) matrix; zeros mark future keys.
            self.register_buffer(
                "mask", torch.tril(torch.ones(1, 1, config.max_len, config.max_len))
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_model) input.

        Returns:
            (batch_size, seq_len, d_model) tensor (dropout applied to the output).
        """
        batch_size, seq_len, _ = x.shape

        key, query, value = self.kqv(x).chunk(3, dim=-1)
        key = key.view(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)
        # key.shape == (batch_size, num_heads, d_head, seq_len)
        query = query.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        # qv.shape == (batch_size, num_heads, seq_len, d_head)

        # get_relative_positions builds its result on the default device; move it to
        # the slopes' device so the module also works on GPU.
        distance = get_relative_positions(seq_len).to(self.m.device).abs()
        # -m * |j - i| equals m * (j - i) for every key a causal model can see, and
        # penalises later keys symmetrically (instead of rewarding them) when not causal.
        bias = (-self.m * distance).unsqueeze(0)
        # bias.shape == (1, num_heads, seq_len, seq_len)

        score = torch.matmul(query, key) / self.scale + bias
        # score.shape == (batch_size, num_heads, seq_len, seq_len)

        if self.causal:
            score = score.masked_fill(
                self.mask[:, :, :seq_len, :seq_len] == 0, float("-inf")
            )

        attn = F.softmax(score, dim=-1)
        out = torch.matmul(attn, value)
        # out.shape == (batch_size, num_heads, seq_len, d_head)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
        # out.shape == (batch_size, seq_len, d_model)
        out = self.dropout(out)

        return out