# Understanding LongNet

## Pre-requisites

#### 1. Matrix multiplication

First, we need to understand how matrix multiplication works.
In 2-D matrix multiplication, an N x M matrix can be multiplied by an M x P matrix: the number of columns of the first matrix must equal the number of rows of the second (the shared dimension M).
The result has dimensions N x P.
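A quick NumPy check of the 2-D rule, as a minimal sketch with arbitrary small shapes. An N x M matrix times an M x P matrix gives N x P. Each entry is a row-by-column dot product. Mismatched inner dimensions raise an error.

```python
import numpy as np

N, M, P = 2, 3, 4
A = np.arange(N * M).reshape(N, M)   # shape (2, 3)
B = np.arange(M * P).reshape(M, P)   # shape (3, 4)

C = A @ B                            # (N x M) @ (M x P) -> (N x P)
print(C.shape)                       # (2, 4)

# Each entry C[i, j] is the dot product of row i of A with column j of B
assert C[1, 2] == sum(A[1, k] * B[k, 2] for k in range(M))

# Inner dimensions must match: (2 x 3) @ (2 x 3) is not defined
try:
    A @ A
except ValueError as err:
    print("shape mismatch:", err)
```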

For N-dimensional matrix multiplication, the rules (as used by NumPy's `np.matmul` and the `@` operator) are:
- If both arguments are 2-D, they are multiplied like conventional matrices.
- If either argument is N-D with N > 2, it is treated as a stack of matrices residing in the last two indexes, and the remaining dimensions are broadcast accordingly.

Example:

Matrix multiplication of two arrays with more than 2 dimensions works like this. Consider two arrays of shapes 3x2x5 and 3x5x4.
The first array is a stack of three 2x5 matrices, and the second is a stack of three 5x4 matrices. Each pair is multiplied as (2x5) x (5x4) = (2x4), so the result has shape 3x2x4.
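The stack-of-matrices view can be checked directly in NumPy. This sketch uses random data. Multiplying a 3x2x5 array by a 3x5x4 array gives the same result as multiplying the three pairs of 2-D slices one at a time.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 2, 5))   # three 2 x 5 matrices
Y = rng.standard_normal((3, 5, 4))   # three 5 x 4 matrices

batched = np.matmul(X, Y)            # shape (3, 2, 4)

# Same thing, one 2-D product per slice along the first axis
looped = np.stack([X[i] @ Y[i] for i in range(3)])

print(batched.shape)                 # (3, 2, 4)
print(np.allclose(batched, looped))  # True
```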

The first (batch) dimension must be the same in both arrays or be broadcastable. For example, (3x2x5, 3x5x4), (3x2x5, 1x5x4), and (1x2x5, 3x5x4) are all valid, and each gives a 3x2x4 result. A dimension of size 1 is broadcast, i.e. repeated to match the corresponding dimension of the other array.
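A small NumPy sketch of what "broadcast" means here, using random data for illustration. A leading dimension of size 1 behaves as if that array were repeated along the axis. Leading sizes that differ, where neither is 1, are rejected.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 2, 5))
B = rng.standard_normal((1, 5, 4))       # leading size 1 -> broadcast to 3

out = A @ B
print(out.shape)                         # (3, 2, 4)

# Broadcasting behaves as if B were repeated 3 times along its first axis
B_repeated = np.repeat(B, 3, axis=0)     # shape (3, 5, 4)
print(np.allclose(out, A @ B_repeated))  # True

# Leading sizes 2 and 3 differ and neither is 1, so this is invalid
try:
    np.ones((2, 2, 5)) @ np.ones((3, 5, 4))
except ValueError:
    print("(2, 2, 5) @ (3, 5, 4) cannot be broadcast")
```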

Similarly, let's take 5-D arrays:

```python
import numpy as np

a = np.ones([6, 2, 4, 7, 4])
c = np.ones([6, 1, 1, 4, 3])
```

Multiplying them gives an array of shape (6, 2, 4, 7, 3). The last two dimensions follow (7x4) x (4x3) = (7x3). The size-1 2nd dimension of `c` is broadcast to 2, and the size-1 3rd dimension of `c` is broadcast to 4.
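Running the 5-D example confirms the shape. In this sketch, `np.broadcast_shapes` (available in NumPy 1.20+) computes the broadcast batch shape on its own. Every entry of `a` and `c` is 1, so every entry of the product equals the inner dimension, 4.

```python
import numpy as np

a = np.ones([6, 2, 4, 7, 4])
c = np.ones([6, 1, 1, 4, 3])

result = np.matmul(a, c)
print(result.shape)                                     # (6, 2, 4, 7, 3)

# Batch dimensions (all but the last two) are broadcast together
print(np.broadcast_shapes(a.shape[:-2], c.shape[:-2]))  # (6, 2, 4)

# Each entry is a sum of 4 products of ones, i.e. 4.0
print(np.all(result == 4.0))                            # True
```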

So all dimensions other than the last two must be equal or broadcastable, and the last two dimensions must follow the (N x M) x (M x P) -> (N x P) rule.
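The whole rule fits in a small function. The helper below computes the result shape from the two input shapes alone and is checked against `np.matmul`. `matmul_shape` is a hypothetical name used only for illustration, not a NumPy function. The sketch assumes both inputs have at least 2 dimensions, since NumPy treats 1-D inputs specially.

```python
import numpy as np

def matmul_shape(shape_a, shape_b):
    """Result shape of a @ b for arrays with >= 2 dims (hypothetical helper)."""
    *batch_a, n, m = shape_a
    *batch_b, m2, p = shape_b
    if m != m2:
        raise ValueError(f"inner dimensions differ: {m} vs {m2}")

    # Broadcast the batch dimensions right-to-left; missing dims count as 1
    batch = []
    for i in range(1, max(len(batch_a), len(batch_b)) + 1):
        da = batch_a[-i] if i <= len(batch_a) else 1
        db = batch_b[-i] if i <= len(batch_b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dimensions {da} and {db} are not broadcastable")
        batch.append(db if da == 1 else da)
    return tuple(reversed(batch)) + (n, p)

for sa, sb in [((2, 5), (5, 4)), ((3, 2, 5), (1, 5, 4)), ((6, 2, 4, 7, 4), (6, 1, 1, 4, 3))]:
    expected = np.matmul(np.ones(sa), np.ones(sb)).shape
    print(sa, sb, matmul_shape(sa, sb), matmul_shape(sa, sb) == expected)
```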




#### 2. Attention Mechanisms

Explanation: [TODO]
Types: [TODO]
How to calculate complexity: [TODO]

##### Additive attention
Explanation: [TODO]

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # These will store the learnable weights of our attention mechanism
        self.query_layer = nn.Linear(hidden_dim, hidden_dim)
        self.key_layer = nn.Linear(hidden_dim, hidden_dim)
        self.energy_layer = nn.Linear(hidden_dim, 1)

    def forward(self, query, key, value, mask=None):
        # Query shape: [batch_size, 1, hidden_dim]
        # Key and Value shapes: [batch_size, seq_len, hidden_dim]
        # Optional mask: broadcastable to [batch_size, seq_len, 1], 0 = ignore position

        query = self.query_layer(query)
        key = self.key_layer(key)

        # Attention energies: the single query is broadcast against every key
        # -> [batch_size, seq_len, 1]
        energies = self.energy_layer(torch.tanh(query + key))
        print(f"SCORES is {energies.shape}")

        if mask is not None:
            energies = energies.masked_fill(mask == 0, -1e10)

        # Convert energies to attention probabilities over the sequence dimension
        attention = F.softmax(energies, dim=1)
        print(f"ATTENTION is {attention.shape}")

        # Context vector: weighted sum of the values -> [batch_size, 1, hidden_dim]
        context = torch.bmm(attention.transpose(1, 2), value)
        print(f"OUTPUT is {context.shape}")
        return context, attention
```




```python
# Example usage: one query vector per batch element attends over 5 key/value positions
attention = AdditiveAttention(hidden_dim=128)
query = torch.randn(size=(32, 1, 128))
key = torch.randn(size=(32, 5, 128))
value = torch.randn(size=(32, 5, 128))
out = attention(query, key, value)
```

```
SCORES is torch.Size([32, 5, 1])
ATTENTION is torch.Size([32, 5, 1])
OUTPUT is torch.Size([32, 1, 128])
```
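Here is a framework-free NumPy sketch of what `AdditiveAttention` computes for a single batch element. It follows the same steps:

1. Project the query and the keys, then add them.
2. Apply tanh, and reduce to one score per position with a vector `v`.
3. Apply softmax over the sequence.
4. Take the weighted sum of the values.

The weight names `W_q`, `W_k` and `v` are assumptions for illustration. The biases that `nn.Linear` adds are left out.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, seq_len = 8, 5

q = rng.standard_normal(hidden_dim)              # one query vector
K = rng.standard_normal((seq_len, hidden_dim))   # keys, one row per position
V = rng.standard_normal((seq_len, hidden_dim))   # values, one row per position

# Hypothetical learnable parameters (random here)
W_q = rng.standard_normal((hidden_dim, hidden_dim))
W_k = rng.standard_normal((hidden_dim, hidden_dim))
v = rng.standard_normal(hidden_dim)

# score_i = v . tanh(W_q q + W_k k_i): the query is broadcast against every key
energies = np.tanh(q @ W_q.T + K @ W_k.T) @ v    # shape (seq_len,)

# Softmax over the sequence (subtracting the max for numerical stability)
weights = np.exp(energies - energies.max())
weights /= weights.sum()

context = weights @ V                            # weighted sum of values, shape (hidden_dim,)
print(weights.shape, context.shape)              # (5,) (8,)
```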


##### Dot product attention
Explanation: [TODO]

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value, mask=None):
        # Query, key, value shapes: [batch_size, seq_len, d_k]
        scores = torch.matmul(query, key.transpose(-2, -1))
        print(f"SCORES is {scores.shape}")
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention = F.softmax(scores, dim=-1)
        print(f"ATTENTION is {attention.shape}")

        # Apply the attention to the values
        output = torch.matmul(attention, value)
        print(f"OUTPUT is {output.shape}")
        return output, attention
```

```python
attention = Attention()
out = attention(torch.randn(size=(32, 5, 128)), torch.randn(size=(32, 5, 128)), torch.randn(size=(32, 5, 128)))
```

```
SCORES is torch.Size([32, 5, 5])
ATTENTION is torch.Size([32, 5, 5])
OUTPUT is torch.Size([32, 5, 128])
```
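The `mask` argument of `Attention` is not exercised in the example above. A common choice, assumed here for illustration, is a causal mask: position i may only attend to positions <= i. The NumPy sketch below mirrors the module's `masked_fill(mask == 0, -1e9)` followed by a softmax over the last axis. Masked positions end up with zero weight.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_k = 4, 8
Q = rng.standard_normal((seq_len, d_k))
K = rng.standard_normal((seq_len, d_k))
V = rng.standard_normal((seq_len, d_k))

scores = Q @ K.T                                   # (seq_len, seq_len)

# Causal mask: 1 on and below the diagonal, 0 above it
mask = np.tril(np.ones((seq_len, seq_len)))
scores = np.where(mask == 0, -1e9, scores)         # same effect as masked_fill

# Row-wise softmax (last axis), as in F.softmax(scores, dim=-1)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

output = weights @ V                               # (seq_len, d_k)
print(np.round(weights, 3))                        # upper triangle is all zeros
```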


##### Scaled dot-product attention
Explanation: [TODO]

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.scaling = 1. / d_k**0.5

    def forward(self, query, key, value, mask=None):
        # Query, key, value shapes: [batch_size, n_heads, seq_len, d_k]
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scaling
        print(f"SCORES is {scores.shape}")
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention = F.softmax(scores, dim=-1)
        print(f"ATTENTION is {attention.shape}")

        # Apply the attention to the values
        output = torch.matmul(attention, value)
        print(f"OUTPUT is {output.shape}")
        return output, attention
```


```python
# Use the module. d_k must match the last dimension of the query/key tensors (128 here);
# otherwise the scores are scaled by the wrong factor.
d_k = 128
attention_module = ScaledDotProductAttention(d_k)

out = attention_module(
    torch.randn(size=(32, 1, 5, 128)),
    torch.randn(size=(32, 1, 5, 128)),
    torch.randn(size=(32, 1, 5, 128)),
)
```

```
SCORES is torch.Size([32, 1, 5, 5])
ATTENTION is torch.Size([32, 1, 5, 5])
OUTPUT is torch.Size([32, 1, 5, 128])
```
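Why scale by 1/sqrt(d_k)? Suppose the entries of a query and a key are independent with mean 0 and variance 1. Then their dot product has variance d_k. Unscaled scores therefore grow with the head size and push the softmax toward a one-hot distribution.

A quick NumPy simulation illustrates this, assuming random Gaussian vectors. The standard deviation of the raw scores comes out at about sqrt(d_k), and at about 1 after the same scaling that `ScaledDotProductAttention` applies.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_samples = 128, 10_000

q = rng.standard_normal((n_samples, d_k))
k = rng.standard_normal((n_samples, d_k))

raw = np.einsum("nd,nd->n", q, k)        # one dot product per sample
scaled = raw / np.sqrt(d_k)              # same scaling as ScaledDotProductAttention

print(raw.std(), np.sqrt(d_k))           # both roughly 11.3
print(scaled.std())                      # roughly 1.0
```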


##### Sparse attention
Explanation: [TODO]