<!DOCTYPE html>
<html>
<head>
<style>
    body {
        font-family: Arial, sans-serif; /* Change the font to Arial */
    }
    
    .red-text {
        color: red; /* Change the text color to red */
    }
    
    .blue-text {
        color: blue; /* Change the text color to blue */
    }
</style>
</head>
<body>


<div>
    <div style="width: 200; float: left;">
        <h1 style="color: #591101;  font-family: Arial, sans-serif; font-weight: 900;">Perceiver Explained</h1>
        <p style="color: #878787;  font-family: Arial, sans-serif; font-weight: 700;">Perceiver: General Perception with Iterative Attention</p>
        <hr>
        <h4>Author:</h4>
        <pre>➢ Clint Morris</pre>
        <h4>Publication:</h4>
        <pre><a href="https://arxiv.org/pdf/2103.03206.pdf">➢ arXiv:2103.03206</a></pre>
        <h4>Code Base:</h4>
        <pre><a href="https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_pytorch.py">➢ Pytorch - Phil Wang</a>
<a href="https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_pytorch.py">➢ Tensorflow - Clint Morris</a></pre>
            <br><br>
            <img src="https://i.ibb.co/k9dtKFm/Capture.png"  width="450">
            <img src="https://i.ibb.co/gTtgNfY/My-project.png" width="120" style="position: absolute; bottom: 0; left: 0;">
    </div>
    <div style="width: 500; float: right;">
        <img src="https://s10.gifyu.com/images/demo38de86fabd82634d.gif"  width="500">
    </div>
</div>

<br><br>
<h2>Stage 1 - Cross Attention</h2>
<hr>
<img src="https://i.ibb.co/Y7mtjh9/perceiver-drawio2-Copy.png"  width="800">

<h3>Stage 1.1 - PreNorm</h3>

The PreNorm class is a custom PyTorch module that applies layer normalization to the input (and, when context_dim is given, to the tensor passed as the context keyword argument) before handing both to the wrapped function or layer fn. Normalizing each token across its feature dimension before the sub-layer helps keep activations in a stable range, which generally makes deep attention stacks easier to train.

In [None]:
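import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Reduce

# Helpers used by the classes below
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        # A second LayerNorm is created only when a context (cross-attention input) is expected
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)

        # Normalize the context too, then pass it on under the same keyword
        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context = normed_context)

        return self.fn(x, **kwargs)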
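
To make the effect of the normalization step concrete, here is a minimal sketch (written for this walkthrough, not taken from the Perceiver code base) that applies nn.LayerNorm to a random tensor and checks that every token ends up with roughly zero mean and unit variance across its features. The tensor shape and the scale and offset constants are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical input: batch of 2, 5 tokens, 16 features, deliberately off-centre
x = torch.randn(2, 5, 16) * 3.0 + 7.0

norm = nn.LayerNorm(16)   # normalizes over the last (feature) dimension
y = norm(x)

# Per-token statistics over the feature dimension, before and after
print("mean before:", x.mean(dim=-1)[0, 0].item())
print("mean after: ", y.mean(dim=-1)[0, 0].item())
print("var after:  ", y.var(dim=-1, unbiased=False)[0, 0].item())
```

PreNorm performs exactly this step on x (and on the context, if one is configured) before calling the wrapped layer.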

<h3>Stage 1.2 - Attention</h3>

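The Attention class implements multi-head scaled dot-product attention. Queries are projected from the input x, and keys and values are projected from context; when no context is given, context defaults to x, so the same module also serves as self-attention. The heads are folded into the batch dimension, similarities are scaled by dim_head ** -0.5, and positions where mask is False are filled with the most negative representable value so that they receive zero weight after the softmax. Finally, the per-head outputs are concatenated and projected back to query_dim. In the Perceiver's cross-attention, the queries come from the latent array and the context is the input data.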

In [None]:
class Attention(nn.Module):
    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context = None, mask = None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h = h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)
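
The following is a simplified, single-head sketch of the same computation in plain PyTorch, intended only to show the shapes involved in cross-attention and the effect of the mask. The sizes are hypothetical, and the learned projections (to_q, to_kv, to_out) are left out, so q, k, and v are random tensors here.

```python
import torch

torch.manual_seed(0)

b, n_latents, n_inputs, d = 2, 4, 10, 8   # hypothetical sizes

q = torch.randn(b, n_latents, d)   # queries come from the latent array
k = torch.randn(b, n_inputs, d)    # keys come from the input (context)
v = torch.randn(b, n_inputs, d)    # values come from the input (context)

# Boolean mask over input positions: True = keep, False = ignore
mask = torch.ones(b, n_inputs, dtype=torch.bool)
mask[:, -3:] = False               # pretend the last 3 inputs are padding

# Scaled dot-product similarities, shape (b, n_latents, n_inputs)
sim = torch.einsum('bid,bjd->bij', q, k) * d ** -0.5
sim = sim.masked_fill(~mask[:, None, :], -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim=-1)         # each row sums to 1 over input positions
out = torch.einsum('bij,bjd->bid', attn, v)

print(out.shape)                   # torch.Size([2, 4, 8])
print(attn[0, 0, -3:])             # masked positions receive zero weight
```

The output has one row per latent regardless of how many input positions there are, which is why the cost of the later latent layers does not depend on the input length.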

<h4>Stage 1.2.1 - chunk demo</h4>

In [4]:
import torch

# Create a 2x4 tensor
tensor = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8]
])

# Split the tensor into 2 chunks along the last dimension (columns)
chunks = torch.chunk(tensor, chunks=2, dim=-1)

# Print the original tensor and the resulting chunks
print("Original Tensor:")
print(tensor)
for i, chunk in enumerate(chunks):
    print(f"Chunk {i + 1}:")
    print(chunk)

Original Tensor:
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
Chunk 1:
tensor([[1, 2],
        [5, 6]])
Chunk 2:
tensor([[3, 4],
        [7, 8]])


<h4>Stage 1.2.2 - rearrange demo</h4>

In [27]:
import torch
from einops import rearrange

# Input tensor shape: (batch_size, num_items, num_groups * group_dim)
tensor = torch.tensor([
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8]
    ],
    [
        [9, 10, 11, 12],
        [13, 14, 15, 16]
    ]
])

# Parameters:
batch_size = 2
num_items = 2
num_groups = 2
group_dim = 2

# Rearrange the tensor using the specified pattern
output = rearrange(tensor, 'b n (h d) -> (b h) n d', h=num_groups)

# Print the input tensor and the rearranged output tensor
print("Input Tensor (shape: {}):".format(tensor.shape))
print(tensor)
print("\nRearranged Output Tensor (shape: {}):".format(output.size()))
print(output)

Input Tensor (shape: torch.Size([2, 2, 4])):
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12],
         [13, 14, 15, 16]]])

Rearranged Output Tensor (shape: torch.Size([4, 2, 2])):
tensor([[[ 1,  2],
         [ 5,  6]],

        [[ 3,  4],
         [ 7,  8]],

        [[ 9, 10],
         [13, 14]],

        [[11, 12],
         [15, 16]]])


In [29]:
import torch

# Input tensor shape: (batch_size, num_items, num_groups * group_dim)
tensor = torch.tensor([
                        [[1, 2, 3, 4],
                         [5, 6, 7, 8]],
                        [[9, 10, 11, 12],
                         [13, 14, 15, 16]]
                      ])

# Parameters:
batch_size = 2
num_items = 2
num_groups = 2
group_dim = 2

# Rearrange the tensor using PyTorch view() and permute()
reshaped_tensor = tensor.view(batch_size, num_items, num_groups, group_dim)
output = reshaped_tensor.permute(0, 2, 1, 3).contiguous().view(batch_size * num_groups, num_items, group_dim)

# Print the input tensor and the rearranged output tensor
print("Input shape: {}:".format(tensor.shape))
print(tensor)
print("\nRearranged Output shape: {}:".format(output.shape))
print(output)

Input shape: torch.Size([2, 2, 4]):
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12],
         [13, 14, 15, 16]]])

Rearranged Output shape: torch.Size([4, 2, 2]):
tensor([[[ 1,  2],
         [ 5,  6]],

        [[ 3,  4],
         [ 7,  8]],

        [[ 9, 10],
         [13, 14]],

        [[11, 12],
         [15, 16]]])


<h4>Stage 1.2.3 - einsum demo</h4>

In [33]:
import torch

# Input tensors
attn = torch.tensor([
    [
        [0.5, 0.5],
        [0.3, 0.7]
    ],
    [
        [0.2, 0.8],
        [0.4, 0.6]
    ]
])

print(f'attn shape: {attn.shape}')

v = torch.tensor([
    [
        [1., 2., 3.],
        [3., 4., 5.]
    ],
    [
        [5., 6., 8.],
        [7., 8., 9.]
    ]
])

print(f'value shape: {v.shape}\n')

# Perform the einsum operation
result = torch.einsum('b i j, b j d -> b i d', attn, v)

print("Result shape: {}:".format(result.shape))
print(result)

attn shape: torch.Size([2, 2, 2])
value shape: torch.Size([2, 2, 3])

Result shape: torch.Size([2, 2, 3]):
tensor([[[2.0000, 3.0000, 4.0000],
         [2.4000, 3.4000, 4.4000]],

        [[6.6000, 7.6000, 8.8000],
         [6.2000, 7.2000, 8.6000]]])


In [34]:
attn = torch.tensor([
    [
        [0.5, 0.5],
        [0.3, 0.7],
        [0.4, 0.6],
        [0.4, 0.6]
    ],
    [
        [0.2, 0.8],
        [0.4, 0.6],
        [0.4, 0.6],
        [0.4, 0.6]
    ]
])

print(f'attn shape: {attn.shape}')

v = torch.tensor([
    [
        [3.],
        [5.]
    ],
    [
        [8.],
        [9.]
    ]
])

print(f'value shape: {v.shape}\n')

# Perform the einsum operation
result = torch.einsum('b i j, b j d -> b i d', attn, v)

print("Result Tensor (shape: {}):".format(result.shape))
print(result)

attn shape: torch.Size([2, 4, 2])
value shape: torch.Size([2, 2, 1])

Result Tensor (shape: torch.Size([2, 4, 1])):
tensor([[[4.0000],
         [4.4000],
         [4.2000],
         [4.2000]],

        [[8.8000],
         [8.6000],
         [8.6000],
         [8.6000]]])
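
For readers who think in matrix notation, the pattern 'b i j, b j d -> b i d' is a batched matrix multiplication. The sketch below reuses the values of the first einsum example and compares the result with torch.bmm and the @ operator; the pattern is written without spaces here, and the variable names are chosen for illustration.

```python
import torch

attn = torch.tensor([[[0.5, 0.5], [0.3, 0.7]],
                     [[0.2, 0.8], [0.4, 0.6]]])
v = torch.tensor([[[1., 2., 3.], [3., 4., 5.]],
                  [[5., 6., 8.], [7., 8., 9.]]])

via_einsum = torch.einsum('bij,bjd->bid', attn, v)
via_bmm = torch.bmm(attn, v)     # batched matmul: (b, i, j) @ (b, j, d)
via_matmul = attn @ v            # same operation with the @ operator

print(torch.allclose(via_einsum, via_bmm))      # True
print(torch.allclose(via_einsum, via_matmul))   # True
```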


In [None]:
get_cross_attn = lambda: PreNorm(latent_dim, 
                                 Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, 
                                           dropout = attn_dropout), 
                                 context_dim = input_dim)

<br><br>
<h2>Stage 3 - Cross Attention FFN</h2>
<hr>
<img src="https://i.ibb.co/th9399Z/perceiver-drawio2-1.png"  width="300">

In [None]:
class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
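
The first linear layer in FeedForward outputs dim * mult * 2 features because GEGLU splits them into two halves and uses one half as a gate for the other. The sketch below traces that shape flow step by step with standalone layers; the sizes and the names w_in and w_out are hypothetical, and dropout is omitted.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

dim, mult = 8, 4                               # hypothetical sizes
x = torch.randn(2, 5, dim)

w_in = torch.nn.Linear(dim, dim * mult * 2)    # doubled width: values + gates
w_out = torch.nn.Linear(dim * mult, dim)

h = w_in(x)                                    # (2, 5, 64)
values, gates = h.chunk(2, dim=-1)             # two halves of shape (2, 5, 32)
gated = values * F.gelu(gates)                 # GELU of one half gates the other
y = w_out(gated)                               # back to (2, 5, 8)

print(h.shape, gated.shape, y.shape)
```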

The Gaussian Error Linear Unit (GELU) is an activation function proposed by Hendrycks and Gimpel in their 2016 paper "Gaussian Error Linear Units (GELUs)" (https://arxiv.org/abs/1606.08415). It is defined as GELU(x) = x * Φ(x), where Φ is the cumulative distribution function of the standard normal distribution, and it is widely used in transformer-based architectures such as BERT. In the FeedForward block above, GEGLU splits the output of the first linear layer into two halves and multiplies one half by the GELU of the other.

Some benefits of using GELU are:

- Smoothness: GELU is a smooth, differentiable function, which makes it suitable for gradient-based optimization algorithms used in deep learning.
- Non-monotonic: For negative inputs GELU dips slightly below zero, reaching a minimum near x ≈ -0.75, and then rises back toward zero as x becomes more negative. Unlike ReLU, it therefore passes a small negative signal for inputs just below zero instead of clipping them.
- Better gradient propagation: Because GELU has non-zero gradients for negative as well as positive inputs, it is less prone than ReLU to the "dying ReLU" problem, where a neuron's gradient becomes zero and the neuron stops contributing to the learning process.
- Improved performance: Hendrycks and Gimpel report improvements over ReLU and ELU across the vision, language, and speech tasks they considered. Gains in other settings depend on the model and the data.
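
The properties listed above can be checked numerically. The sketch below is an illustration written for this walkthrough, not code from the Perceiver repository: it implements the exact GELU, x * Φ(x), with Python's math.erf and estimates derivatives with a central finite difference. The helper names gelu, relu, and grad are hypothetical.

```python
import math

def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def relu(x: float) -> float:
    return max(0.0, x)

def grad(f, x: float, h: float = 1e-5) -> float:
    """Central finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

for x in (-3.0, -1.0, -0.5, 0.5, 1.0, 3.0):
    print(f"x={x:+.1f}  gelu={gelu(x):+.4f}  relu={relu(x):+.4f}  "
          f"gelu'={grad(gelu, x):+.4f}  relu'={grad(relu, x):+.4f}")
```

At x = -1 the GELU output is lower than at x = -3, which shows the non-monotonic dip, and its derivative is non-zero at negative inputs where the ReLU derivative is exactly zero.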

<img src="https://cdn-images-1.medium.com/v2/resize:fit:623/0*lpKy2FLJ8NV0QkGY"  width="500">
<img src="https://cvml-expertguide.net/wp-content/uploads/2021/08/activation_functions-1.png"  width="500">

In [23]:
get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))

<br><br>
<h2>Stage 4 - Latent Attention</h2>
<hr>
<img src="https://i.ibb.co/Kqxkq06/perceiver-drawio2-Copy2.png"  width="500">

In [24]:
get_latent_attn = lambda: PreNorm(latent_dim, 
                                  Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, 
                                            dropout = attn_dropout)
                                 )

<br><br>
<h2>Stage 5 - Latent Attention FFN</h2>
<hr>
<img src="https://i.ibb.co/RYXXqvm/perceiver-drawio2-Copy3.png"  width="500">

In [None]:
get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))

In [None]:
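# Body of the model's forward pass; the default values in the signature are assumptions
def forward(self, data, mask = None, return_embeddings = False):
    b = data.shape[0]  # batch size

    # flatten all input axes between batch and channels into one sequence axis
    data = rearrange(data, 'b ... d -> b (...) d')
    # give every batch element its own copy of the learned latent array
    x = repeat(self.latents, 'n d -> b n d', b = b)

    # layers
    for cross_attn, cross_ff, self_attns in self.layers:
        # the latents query the input, with a residual connection
        x = cross_attn(x, context = data, mask = mask) + x
        x = cross_ff(x) + x

        # stack of latent self-attention blocks
        for self_attn, self_ff in self_attns:
            x = self_attn(x) + x
            x = self_ff(x) + x

    if return_embeddings:
        return x

    # to logits
    return self.to_logits(x)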
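
The two einops calls at the top of the forward pass can be mirrored in plain PyTorch, in the same spirit as the view() and permute() version of the rearrange demo earlier. This is a sketch with hypothetical sizes; the standalone latents tensor stands in for self.latents.

```python
import torch

torch.manual_seed(0)

b, num_latents, latent_dim = 2, 6, 16       # hypothetical sizes
data = torch.randn(b, 4, 4, 3)              # e.g. a 4x4 grid with 3 channels
latents = torch.randn(num_latents, latent_dim)

# Same shape and values as rearrange(data, 'b ... d -> b (...) d'):
# merge every axis between batch and channels into one sequence axis
flat = data.flatten(start_dim=1, end_dim=-2)

# Same shape and values as repeat(latents, 'n d -> b n d', b = b):
# one copy of the latent array per batch element
x = latents.unsqueeze(0).expand(b, -1, -1)

print(flat.shape)   # torch.Size([2, 16, 3])
print(x.shape)      # torch.Size([2, 6, 16])
```

Every batch element starts from the same latent array; the copies only diverge once cross-attention mixes in that element's input data.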

In [None]:
self.to_logits = nn.Sequential(
    Reduce('b n d -> b d', 'mean'),
    nn.LayerNorm(latent_dim),
    nn.Linear(latent_dim, num_classes)
) if final_classifier_head else nn.Identity()
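
The classifier head first averages over the latent axis, so the number of latents does not affect the size of the logits. The sketch below reproduces that shape flow in plain PyTorch, using x.mean(dim=1) as the counterpart of Reduce('b n d -> b d', 'mean'). The sizes are hypothetical and the layers are freshly initialized, so only the shapes are meaningful.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

b, num_latents, latent_dim, num_classes = 2, 6, 16, 10   # hypothetical sizes
x = torch.randn(b, num_latents, latent_dim)              # final latent array

pooled = x.mean(dim=1)    # average over the latent axis: (b, n, d) -> (b, d)
head = nn.Sequential(
    nn.LayerNorm(latent_dim),
    nn.Linear(latent_dim, num_classes),
)
logits = head(pooled)

print(pooled.shape)   # torch.Size([2, 16])
print(logits.shape)   # torch.Size([2, 10])
```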